{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "tutorial.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JndnmDMp66FL"
      },
      "source": [
        "##### Copyright 2021 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hMqWDc_m6rUC"
      },
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n8uMAcMvCrmN"
      },
      "source": [
        "# Generalized Gumbel-max causal mechanisms tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FDV72jCXCxr0"
      },
      "source": [
        "This notebook explains the APIs of the Gumbel-max causal mechanism implementation, along with those of our two gadgets."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UNnd2YesSAFr"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/google-research/blob/master/gumbel_max_causal_gadgets/tutorial.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tr2Z0N-qSNz6"
      },
      "source": [
        "## Setting up the environment"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "obGYeV4EUb61"
      },
      "source": [
        "These instructions are written for running this tutorial in Google Colab; if you are using a different environment, the setup steps may differ.\n",
        "\n",
        "First, connect the Colab runtime to a GPU using the \"Runtime > Change runtime type\" menu above.\n",
        "Next, download the codebase and install the necessary dependencies:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_j6tXnzFn-s6",
        "outputId": "45959936-cd47-4f92-cec7-53de7a9945da"
      },
      "source": [
        "# Download the codebase\n",
        "!git clone https://github.com/google-research/google-research.git --depth=1"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "fatal: destination path 'google-research' already exists and is not an empty directory.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NsacX7KEsp0D"
      },
      "source": [
        "import os\n",
        "os.chdir(\"google-research\")"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AFH12ED4-5fU",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b842f2ae-8caa-4d38-e41a-ebb94a940f4a"
      },
      "source": [
        "# Install Python packages\n",
        "!pip install flax optax"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: flax in /usr/local/lib/python3.7/dist-packages (0.3.6)\n",
            "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (0.0.9)\n",
            "Requirement already satisfied: jax>=0.2.21 in /usr/local/lib/python3.7/dist-packages (from flax) (0.2.21)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax) (3.2.2)\n",
            "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax) (1.0.2)\n",
            "Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.7/dist-packages (from flax) (1.19.5)\n",
            "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.21->flax) (1.4.1)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.21->flax) (0.12.0)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.21->flax) (3.3.0)\n",
            "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax) (0.0.8)\n",
            "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.71+cuda111)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.21->flax) (1.15.0)\n",
            "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.2)\n",
            "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n",
            "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (2.0)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax) (0.11.0)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax) (2.4.7)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax) (2.8.2)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax) (1.3.2)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l1pPA8JrKmh2"
      },
      "source": [
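        "import os\n",
        "# Use the \"platform\" allocator so that JAX allocates GPU memory on demand\n",
        "# instead of preallocating a large fraction of it up front. Run this cell\n",
        "# before importing JAX so the setting takes effect.\n",
        "os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\""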
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yBXGGg-9VxYT",
        "outputId": "a5955778-14ad-485f-a95c-f0e546102b32"
      },
      "source": [
        "import jax\n",
        "jax.devices()"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[GpuDevice(id=0, process_index=0)]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QWHbsCIrmclj"
      },
      "source": [
        "## Imports and configuration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q1D5b0C0Dlb4"
      },
      "source": [
        "import functools\n",
        "import time\n",
        "from typing import *\n",
        "\n",
        "import numpy as np\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import flax\n",
        "import flax.linen as nn\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib as mpl\n",
        "\n",
        "plt.ion()\n",
        "np.set_printoptions(linewidth=150)\n"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "97l0VhKaNoxO"
      },
      "source": [
        "from gumbel_max_causal_gadgets import coupling_util\n",
        "from gumbel_max_causal_gadgets import gadget_1\n",
        "from gumbel_max_causal_gadgets import gadget_2\n",
        "from gumbel_max_causal_gadgets import experiment_util"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DwymhgwqDNcl"
      },
      "source": [
        "## The Gumbel-max causal mechanism and Gumbel-max coupling"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PV4IedSWDiMx"
      },
      "source": [
        "We start with the Gumbel-max causal mechanism, as introduced in \"Counterfactual off-policy evaluation with Gumbel-max structural causal models\" [(Oberst and Sontag, 2019)](http://proceedings.mlr.press/v97/oberst19a.html).\n",
        "\n",
        "Suppose we wish to sample an observation $x$ from an interventional distribution $p(x \\mid do(y)) \\propto \\exp(l_x)$, defined by a vector of logits $l \\in \\mathbb{R}^k$. We can do this by first sampling a vector $\\gamma \\in \\mathbb{R}^k$ of i.i.d. standard Gumbel (Gumbel(0)) exogenous noise, then shifting it by $l$ and taking the argmax, $x = \\arg\\max_i (\\gamma_i + l_i)$:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "25eBIpVX__I6"
      },
      "source": [
        "def sample_gumbel_max(rng, logits):\n",
        "  gumbels = jax.random.gumbel(rng, logits.shape)\n",
        "  x = jnp.argmax(gumbels + logits)\n",
        "  return x"
      ],
      "execution_count": 10,
      "outputs": []
    },
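    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check, here is a minimal sketch showing that the empirical frequencies of Gumbel-max samples approach $\\mathrm{softmax}(l)$. The names `demo_logits`, `demo_keys`, and `demo_samples`, the example logits, and the sample count are arbitrary illustrative choices; `jax.vmap` batches `sample_gumbel_max` over many keys while sharing the same logits."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Illustrative check: Gumbel-max sample frequencies should match softmax(logits).\n",
        "demo_logits = jnp.array([1.0, 0.0, -1.0, 2.0])  # arbitrary example logits\n",
        "demo_keys = jax.random.split(jax.random.PRNGKey(0), 50_000)\n",
        "# Map over the keys (axis 0) while sharing the same logits across the batch.\n",
        "demo_samples = jax.vmap(sample_gumbel_max, in_axes=(0, None))(demo_keys, demo_logits)\n",
        "# Count how often each outcome was sampled and normalize to frequencies.\n",
        "empirical = np.bincount(np.asarray(demo_samples), minlength=demo_logits.shape[0]) / demo_samples.shape[0]\n",
        "print(\"empirical\", empirical)\n",
        "print(\"softmax  \", jax.nn.softmax(demo_logits))"
      ],
      "execution_count": null,
      "outputs": []
    },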
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8zVCRb-EH0XT"
      },
      "source": [
        "To jointly sample two outcomes under two different interventions, we pass two different logit vectors while re-using the same `gumbels`. Because `rng` fully determines the sampled Gumbel noise, we can do this simply by passing the same `rng` value to both calls:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NsO9nX1yHz0D",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0f26625c-3a15-45ba-b6ad-c898ee8cf7c7"
      },
      "source": [
        "# Two fairly arbitrary logit vectors\n",
        "p_logits = 0.1 * jnp.arange(10) - (10 - 1.0) / 2\n",
        "q_logits = -p_logits\n",
        "p_logits = p_logits - jax.scipy.special.logsumexp(p_logits)\n",
        "q_logits = q_logits - jax.scipy.special.logsumexp(q_logits)\n",
        "print(\"p_probs\", jnp.exp(p_logits))\n",
        "print(\"q_probs\", jnp.exp(q_logits))"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "p_probs [0.06120703 0.06764422 0.07475842 0.08262087 0.09131017 0.10091332 0.11152647 0.12325582 0.13621873 0.15054502]\n",
            "q_probs [0.15054502 0.13621877 0.12325585 0.11152647 0.10091332 0.09131017 0.08262087 0.07475844 0.06764424 0.06120703]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XqXur9ZlInGX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c1e2e6a5-3fa5-4a91-8f30-8431943b33e7"
      },
      "source": [
        "keys = jax.random.split(jax.random.PRNGKey(42), 10)\n",
        "p_samples = []\n",
        "q_samples = []\n",
        "for prng_key in keys:\n",
        "  p_samples.append(int(sample_gumbel_max(prng_key, p_logits)))\n",
        "  q_samples.append(int(sample_gumbel_max(prng_key, q_logits)))\n",
        "\n",
        "print(\"p_samples\", p_samples)\n",
        "print(\"q_samples\", q_samples)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "p_samples [9, 5, 2, 2, 2, 1, 4, 9, 7, 1]\n",
            "q_samples [0, 5, 2, 2, 2, 0, 0, 9, 7, 1]\n"
          ]
        }
      ]
    },
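    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Python loop above can also be written with `jax.vmap`. In this sketch (the names `batch_keys`, `sample_batch`, `p_batch`, and `q_batch` are illustrative), the same batch of keys is passed for both logit vectors, so each pair of outcomes shares its Gumbel noise. Since each key fully determines its noise, the results should match `p_samples` and `q_samples` from the loop."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "batch_keys = jax.random.split(jax.random.PRNGKey(42), 10)\n",
        "# Batch over keys (axis 0); the logits are shared across the batch.\n",
        "sample_batch = jax.vmap(sample_gumbel_max, in_axes=(0, None))\n",
        "# Passing the same keys for p and q re-uses the same Gumbel noise for each pair.\n",
        "p_batch = sample_batch(batch_keys, p_logits)\n",
        "q_batch = sample_batch(batch_keys, q_logits)\n",
        "print(\"p_batch\", p_batch.tolist())\n",
        "print(\"q_batch\", q_batch.tolist())"
      ],
      "execution_count": null,
      "outputs": []
    },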
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QI92fupoJOgZ"
      },
      "source": [
        "Note that the samples from $p$ and $q$ agree more often than they would if we drew them independently. This is because they share the same exogenous noise and differ only in their interventional distributions. Repeating this for a larger number of samples lets us visualize the resulting *Gumbel-max coupling* between $p$ and $q$:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "O9ikmsZMJHKD",
        "outputId": "0c288a75-3e95-4577-f606-bdd2c2b754a2"
      },
      "source": [
        "gm_coupling_p_q = coupling_util.joint_from_samples(\n",
        "    coupling_util.gumbel_max_sampler,\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "plt.imshow(gm_coupling_p_q, vmin=0)\n",
        "plt.colorbar()"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b87271b90>"
            ]
          },
          "metadata": {},
          "execution_count": 13
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS0AAAD4CAYAAAC5Z7DGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXO0lEQVR4nO3df4wd13ne8e+zS5GSaJs2SLewSNqkQSouZbeWzVJOlDitGdtU45gpStVSakUNCBBGo8RJYyR0iiqukBZQEFgxEDUFYSpgKcVSQDvF1mHCNFaSIkVKk5Joy5RCd025JikFMX+UsaRS/LFP/7iz9u3N7r2z4szeOzvPBxho7syZM68E6uWZM+fMkW0iIppibNgBRETMRZJWRDRKklZENEqSVkQ0SpJWRDTKojoqve4NS7zshqWV1/vyM5VXGdFIF3iJi35FV1PHB//xUp85e6VU2Se++soB21uu5n5VqSVpLbthKXf9zubK6z1yc+VVRjTSQX/pqus4c/YKXz7w5lJlx9/0v1Zc9Q0rUkvSiojRZ2CKqWGHMWdJWhEtZcwll3s8HCVJWhEtlpZWRDSGMVcaOI0vSSuixaZI0oqIhjBwpYFJq9TgUklbJB2TNClpZ91BRcT8mMKltlEysKUlaRx4EHg/cBI4JGnCdoZ6RjSYgUsN7NMq09LaBEzaPm77IvAosLXesCKibsZcKbmNkjJ9WiuBE12/TwK39BaStAPYAfDaN11fSXARUSPDldHKR6VUNmHa9i7bG21vvP4NS6qqNiJq0hkRX24bJWVaWqeA1V2/VxXHIqLRxBWuas71UJRJWoeA9ZLW0klWdwA/UWtUEVG7Tkf8Akxati9Lugc4AIwDD9k+WntkEVGrzjitBZi0AGzvB/bXHEtEzLOpBra08uXSiJaabmmV2QYZNABd0hJJjxXnD0paUxz/F5KOdG1Tkt7Z715JWhEtZcQVxkpt/XQNQL8N2ADcKWlDT7HtwDnb64AHgPsBbD9i+5223wncBTxn+0i/+yVpRbTYlFVqG6DMAPStwJ5ifx+wWVJvxXcW1/aVCdMRLWXERY+XLb5C0uGu37ts7yr2ywxA/26Z4uXeeWA5cLqrzEcoMdsmSSuipTqDS0s/bJ22vbGuWCTdArxs+2uDytaStF5+Bo68u3QGL2371ycrrxNg941ra6k3YtRVNOShzAD06TInJS0ClgFnus7fAXyuzM3SpxXRUra44rFS2wDfHYAuaTGdBDTRU2YCuLvY3wY8bnc+MSFpDPjnlOjPgjweRrTaVAUtrdkGoEu6DzhsewLYDeyVNAmcpZPYpr0XOGH7eJn7JWlFtFSnI76aFDDTAHTb93btXwBun+XaPwXeU/ZeSVoRLTXHjviRkaQV0WJXGjiNJ0kroqWmR8Q3TZJWRItNDX4zOHKStCJaqjNhOkkrIhrCiEvlp/GMjCStiJayKTNwdOQkaUW0lioZXDrfkrQiWsqkpRURDZOO+IhoDFPqA38jJ0kroqU6S4g1LwU0L+KIqMjCXaw1IhYgkxHxEdEwaWlFRGPYSksrIpqj0xGfaTwR0RjK4NJpkhhbfE3l9T60YX3ldQL8m+OHBxeao3//1r4re0cMXacjvnl9Ws1LsxFRmell7wdtg0jaIumYpElJO2c4v0TSY8X5g5LWdJ37+5L+QtJRSU9LurbfvZK0IlpqekR8ma0fSePAg8BtwAbgTkkbeoptB87ZXgc8ANxfXLsIeBj4mO2bgH8EXOp3vyStiBabYqzUNsAmYNL2cdsX6axf2Lu8/VZgT7G/D9gsScAHgK/a/gqA7TO2r/S7WTriI1rKhktTpdstKyR1d/7usr2r2F8JnOg6dxK4pef675Yp1kk8DywHbgQs6QDwRuBR27/WL5AkrYiW6jwelk5ap21vrCGMRcAPAv8QeBn4kqQnbH9ptgvyeBjRYleK+YeDtgFOAa
u7fq8qjs1YpujHWgacodMq+++2T9t+mc6Cr+/qd7MkrYiWmh7ycLUd8cAhYL2ktZIW01nyfqKnzARwd7G/DXjctoEDwDskXV8ksx8Gnul3s4GPh5JWA/8Z+LvFv+cu258ZdF1EjLpqpvEUfVT30ElA48BDto9Kug84bHsC2A3slTQJnKWT2LB9TtKn6SQ+A/tt/36/+5Xp07oM/ILtJyW9FnhC0n+z3TcbRsToq+ob8bb303m06z52b9f+BeD2Wa59mM6wh1IGJi3bLwAvFPvfkfQsnTcBSVoRDdZ5e7jA5x4Wo1hvBg7OcG4HsAPgWi2tILSIqNOC/9yypNcAnwd+zvbf9J4vxmzsAlg2ttyVRRgRtVmwS4hJuoZOwnrE9hfqDSki5kNTJ0yXeXsoOj3/z9r+dP0hRcR8WagfAbwVuAt4WtKR4tgvF28LIqKhbHF5ISYt238ODXzwjYiBFuTjYUQsTAu2TysiFq4krYhojAU/TisiFp4FO05rzsbH0NLrq6/3wivV1wn8hxur/0zQ3hN/VnmdAHetvrWWeqN9bLhc/iOAIyMtrYgWy+NhRDRG+rQionGcpBURTZKO+IhoDDt9WhHRKOJK3h5GRJM0sU+reWk2IipR4Wo8SNoi6ZikSUk7Zzi/RNJjxfmDxVeQkbRG0v+VdKTY/tOge6WlFdFW7vRrXS1J48CDwPvprGN4SNJEz+I324FzttdJugO4H/hIce4btt9Z9n5paUW02BQqtQ2wCZi0fdz2ReBRYGtPma3AnmJ/H7C5+MDonCVpRbSUi474MtsAK4ETXb9PFsdmLGP7MnAeWF6cWyvpKUl/JumHBt0sj4cRLTaHx8MVkg53/d5VLGZztV4A3mz7jKR3A/9F0k0zLZ4zLUkrosXm8PbwtO3ZvixwCljd9XtVcWymMiclLQKWAWdsG3ilE4ufkPQN4EbgMLPI42FES9mdpFVmG+AQsF7SWkmL6Sx5P9FTZgK4u9jfBjxu25LeWHTkI+mtwHrgeL+bpaUV0WJVjIi3fVnSPcABYBx4yPZRSfcBh21P0FnRa6+kSeAsncQG8F7gPkmXgCngY7bP9rtfklZEi1Ux5KFTj/cD+3uO3du1fwG4fYbrPk9nTdXSkrQiWsqIqUzjiYgmqaihNa+StCLays2ce5ikFdFmDWxqJWlFtFhaWtM0hq67rvp6x8arrxNgrPrOyJ9c977K6wQ48PzBWur94A2l56vGAmFgaipJKyKawkBaWhHRJFWN05pPSVoRbZakFRHNUWpe4chJ0opos7S0IqIxDM7bw4holuYlrdIDlCSNF59E/WKdAUXEPHLJbYTMZVTlx4Fn6wokIoZgoSYtSauAHwU+W284ETFvpgeXltlGSNk+rd8AfhF47WwFJO0AdgBcOz5rsYgYIU0cXDqwpSXpQ8Bf236iXznbu2xvtL1x8VgN8w4jonpTKreNkDItrVuBD0v6J8C1wOskPWz7o/WGFhF100Jsadn+pO1VttfQ+Rj940lYEQtA2U74EUtszftAdERUpGQnfImOeElbJB2TNClp5wznl0h6rDh/UNKanvNvlvSipE8MuteckpbtP7X9oblcExEjrIKWVrFu4YPAbcAG4E5JG3qKbQfO2V4HPADc33P+08AflAk5La2INpsqufW3CZi0fdz2ReBRYGtPma3AnmJ/H7BZkgAk/TjwHHC0TMhJWhFtNbdxWiskHe7adnTVtBI40fX7ZHGMmcrYvgycB5ZLeg3wS8C/Kxt25h5GtNgc3h6etr2xhhA+BTxg+8Wi4TVQklZEm1XzZvAUsLrr96ri2ExlTkpaBCwDzgC3ANsk/RrwemBK0gXbvznbzZK0IuJqHQLWS1pLJzndAfxET5kJ4G7gL4BtdIZOGfih6QKSPgW82C9hQV1Ja0z4uiWVVzta43KHY8tbNtVS74Hnv1xLvVnlZ7RVMbjU9mVJ9wAHgHHgIdtHJd0HHLY9AewG9kqaBM7SSWyvSlpaEW
1lKpuiY3s/sL/n2L1d+xeA2wfU8aky90rSimizERvtXkaSVkSLNXHuYZJWRJslaUVEoyRpRURTyHk8jIimGbEP/JWRpBXRYmlpRUSzJGlFRGOkTysiGidJKyKaRIM/8Ddy8hHAiGiUtLQi2iyPhxHRGOmIj4jGSdKKiEZJ0oqIphDNfHuYpBXRVunTiojGaWDSyjitiDZzyW0ASVskHZM0KWnnDOeXSHqsOH9Q0pri+CZJR4rtK5L+6aB71dLS8vgYU6+7rvJ6x8bqybEar6HesWZ98mPLm+tYhxO+eOpgLfV+aOW7a6m3bap4PJQ0DjwIvJ/O6tKHJE3Yfqar2HbgnO11ku4A7gc+AnwN2Fis6PMm4CuS/muxCvWM0tKKaLNqWlqbgEnbx21fBB4FtvaU2QrsKfb3AZslyfbLXQnq2jJ3S9KKaCt33h6W2YAVkg53bTu6aloJnOj6fbI4xkxliiR1HlgOIOkWSUeBp4GP9WtlQTriI9qt/OPhadu19CHYPgjcJOnvAXsk/UGxTuKM0tKKaLHp78QP2gY4Bazu+r2qODZjGUmLgGXAme4Ctp8FXgTe3u9mSVoRbVZNn9YhYL2ktZIW01nyfqKnzARwd7G/DXjctotrFgFIegvwNuCb/W6Wx8OItio5nGFgNZ03f/cAB4Bx4CHbRyXdBxy2PQHsBvZKmgTO0klsAD8I7JR0CZgC/pXt0/3ul6QV0VKiuhHxtvcD+3uO3du1fwG4fYbr9gJ753KvUo+Hkl4vaZ+kv5T0rKTvn8tNImI0VdSnNa/KtrQ+A/yh7W3FM+v1NcYUEfNlxBJSGQOTlqRlwHuBfwlQDB67WG9YETEvGpi0yjwergW+Dfy2pKckfVbS0t5CknZMDzy7dPnlygONiIqVfDQctcfDMklrEfAu4Lds3wy8BPytCZG2d9neaHvjNYvy9BjRCBVNmJ5PZZLWSeBkMWoVOvOG3lVfSBExX+YwjWdkDExatv8KOCHp+4pDm4Fn+lwSEQ3RxMfDsm8PfwZ4pHhzeBz4qfpCioh5MYKPfmWUSlq2jwD1fHApIoZnoSatiFh4qhwRP5+StCJaTFPNy1pJWhFttZD7tCJiYcrjYUQ0S5JWYXyMy69dUn21dayaU1O9ta3FU1cfxFQ9Iwh/7C231FLv/c/9j8rr/KW19cQ6ytLSiohmSdKKiMbw6E3RKSNJK6KlMk4rIprHzctaWY0nosWqmjAtaYukY5ImJf2tT1dJWiLpseL8QUlriuPvl/SEpKeLf75v0L2StCLaquy3tAYkLUnjwIPAbcAG4E5JG3qKbQfO2V4HPADcXxw/DfyY7XfQWWJs4CIXSVoRLVbR97Q2AZO2jxefY38U2NpTZiuwp9jfB2yWJNtP2X6+OH4UuE5S3/FSSVoRLTaHpLVi+nPqxbajq5qVwImu3yeLY8xUxvZl4DywvKfMPwOetP1Kv5jTER/RVmYuHfGnbdf2eSpJN9F5ZPzAoLJpaUW0WEUd8aeA1V2/VxXHZiwjaRGwDDhT/F4F/B7wk7a/MehmSVoRbVbNwhaHgPWS1hZfN74DmOgpM0Gnox1gG/C4bUt6PfD7wE7bpeZmJWlFtNT04NKrbWkVfVT3AAeAZ4HftX1U0n2SPlwU2w0slzQJ/Gu+t6LXPcA64F5JR4rt7/S7X/q0ItrKruwjgLb3A/t7jt3btX8BuH2G634V+NW53CtJK6LNmjcgPkkros0y9zAimsPU9322GiVpRbRZ83JWklZEm+XxMCIaJUuIRURzZAmx75kah4vLqq96ceU11me8po+r6UoDv49bg53rb628zh1f/3rldQLsuvGttdR7tTqDS5uXtdLSimizBv4dmKQV0WJpaUVEc6RPKyKapbq5h/MpSSuizfJ4GBGNkcVaI6JxGtjSKvURQEk/L+mopK9J+pyka+sOLCLmQTVfLp1XA5OWpJXAzwIbbb8dGKfzOdWIaDhNTZXaRknZx8NFdNYjuwRcDzw/oH
xEjDrTyMGlA1tatk8Bvw58C3gBOG/7j3rLSdoxvSba5Vdeqj7SiKiUMHK5bZSUeTx8A53VYdcCNwBLJX20t5ztXbY32t64aMnS6iONiOrZ5bYBJG2RdEzSpKSdM5xfIumx4vxBSWuK48sl/YmkFyX9ZpmQy3TE/wjwnO1v274EfAH4gTKVR8SIqyBpSRoHHgRuAzYAd0ra0FNsO3DO9jrgAToLswJcAP4t8ImyIZdJWt8C3iPpekkCNtNZJigimmy6T6vM1t8mYNL2cdsXgUfpPJ112wrsKfb3AZslyfZLtv+cTvIqpUyf1sHiJk8CTxfX7Cp7g4gYXRW9PVwJnOj6fbI4NmOZYp3E88DyVxNzqbeHtn8F+JVXc4OIGFXl+qsKKyQd7vq9y/ZQGi8ZER/RVmYuSeu07Y2znDsFrO76vao4NlOZk5IWAcuAM+WD/Z5SI+IjYoGqpk/rELBe0lpJi+kMPp/oKTMB3F3sbwMet1/dWIq0tCJarIoxWLYvS7oHOEBnxsxDto9Kug84bHsC2A3slTQJnKVrVo2kbwKvAxZL+nHgA7afme1+SVoRbVbRwFHb+4H9Pcfu7dq/ANw+y7Vr5nKvJK2ItrKhgQul1JK0PCYuvqaO7rJ6cuw1Y6q8Tl2q5w/DWE1fmqz+v0Dh8uV66q1hasmut62vvE6AH/7qi5XX+cxHKvrzNWJTdMpISyuizZK0IqIxDOQb8RHRHAanTysimsKkIz4iGiZ9WhHRKElaEdEcc5owPTKStCLaysCILVpRRpJWRJulpRURzZFpPBHRJAZnnFZENEpGxEdEo6RPKyIaw87bw4homLS0IqI5jK9cGXYQc5akFdFW+TRNRDROA4c8ZAmxiJYy4CmX2gaRtEXSMUmTknbOcH6JpMeK8wclrek698ni+DFJHxx0ryStiLZy8RHAMlsfksaBB4HbgA3AnZI29BTbDpyzvQ54ALi/uHYDneXEbgK2AP+xqG9WSVoRLeYrV0ptA2wCJm0ft30ReBTY2lNmK7Cn2N8HbJak4vijtl+x/RwwWdQ3q1r6tF4+e/L0oYc/8b9LFF0BnK4jhpo0Kd4mxQrNindOsf7xO2qJ4S1XW8F3OHfgj71vRcni10o63PV7l+1dxf5K4ETXuZPALT3Xf7dMsbjreWB5cfx/9ly7sl8g9SwhZr+xTDlJh21vrCOGOjQp3ibFCs2Kt0mx9mN7y7BjeDXyeBgRV+sUsLrr96ri2IxlJC0ClgFnSl77/0nSioirdQhYL2mtpMV0OtYnespMAHcX+9uAx227OH5H8XZxLbAe+HK/mw17nNauwUVGSpPibVKs0Kx4mxRr7Yo+qnuAA8A48JDto5LuAw7bngB2A3slTQJn6SQ2inK/CzwDXAZ+2nbfnn+5gXOPIqK98ngYEY2SpBURjTK0pDVo2P+okLRa0p9IekbSUUkfH3ZMZUgal/SUpC8OO5Z+JL1e0j5JfynpWUnfP+yY+pH088Wfg69J+pyka4cdU9sMJWmVHPY/Ki4Dv2B7A/Ae4KdHONZuHweeHXYQJXwG+EPbbwP+ASMcs6SVwM8CG22/nU6n8x3Djap9htXSKjPsfyTYfsH2k8X+d+j8T9V3xO6wSVoF/Cjw2WHH0o+kZcB76bxZwvZF2/9nuFENtAi4rhhrdD3w/JDjaZ1hJa2Zhv2PdCIAKGam3wwcHG4kA/0G8IvAqH93ZC3wbeC3i0fZz0paOuygZmP7FPDrwLeAF4Dztv9ouFG1TzriS5L0GuDzwM/Z/pthxzMbSR8C/tr2E8OOpYRFwLuA37J9M/ASMMr9m2+g80SwFrgBWCrpo8ONqn2GlbTmPHR/mCRdQydhPWL7C8OOZ4BbgQ9L+iadx+73SXp4uCHN6iRw0vZ0y3UfnSQ2qn4EeM72t21fAr4A/MCQY2qdYSWtMsP+R0Lx+YzdwLO2Pz3seAax/Unbq2yvofPf9XHbI9kasP1XwAlJ31cc2kxnZPSo+hbwHknXF3
8uNjPCLw4WqqFM45lt2P8wYinhVuAu4GlJR4pjv2x7/xBjWkh+Bnik+MvrOPBTQ45nVrYPStoHPEnnrfJTZErPvMs0noholHTER0SjJGlFRKMkaUVEoyRpRUSjJGlFRKMkaUVEoyRpRUSj/D/0yGF2clFgiQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
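    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To quantify this, we can compare the probability that the two outcomes agree (the total mass on the diagonal) with the agreement probability under independent sampling, $\\sum_i p_i q_i$. This sketch assumes that `joint_from_samples` returns a normalized empirical joint distribution whose entries sum to 1, as the colour scales above suggest; the variable names are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Assumes gm_coupling_p_q is a normalized empirical joint distribution.\n",
        "p_probs = np.exp(np.asarray(p_logits))\n",
        "q_probs = np.exp(np.asarray(q_logits))\n",
        "# Diagonal mass = estimated P(x_p == x_q) under the Gumbel-max coupling.\n",
        "gm_agree = np.trace(np.asarray(gm_coupling_p_q))\n",
        "# Under independent sampling, P(x_p == x_q) = sum_i p_i * q_i.\n",
        "independent_agree = np.sum(p_probs * q_probs)\n",
        "print(\"P(x_p == x_q), Gumbel-max coupling (estimate):\", gm_agree)\n",
        "print(\"P(x_p == x_q), independent sampling:\", independent_agree)"
      ],
      "execution_count": null,
      "outputs": []
    },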
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U-FGD1BdOQL4"
      },
      "source": [
        "### The Gumbel-max coupling vs. a maximal coupling"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tCvj0yDzKTzE"
      },
      "source": [
        "Notice that much of the mass lies on the diagonal. We might wonder whether the diagonal contains as much mass as possible, i.e., whether this is a maximal coupling. As we show in Section 4, the answer is no.\n",
        "\n",
        "For comparison, we can construct a maximal coupling (which does not correspond to a causal mechanism, but is instead defined directly with respect to `p_logits` and `q_logits`):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "f081v2USKDZS",
        "outputId": "38e7e6f9-bc89-417f-e5c0-82becf077d8c"
      },
      "source": [
        "maximal_coupling_p_q = coupling_util.joint_from_samples(\n",
        "    coupling_util.maximal_coupling_sampler,\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "plt.imshow(maximal_coupling_p_q, vmin=0)\n",
        "plt.colorbar()"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b86cae550>"
            ]
          },
          "metadata": {},
          "execution_count": 14
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS0AAAD4CAYAAAC5Z7DGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATkklEQVR4nO3df8xeZX3H8fenv2mB4kCMtDiaFDVFnWhFlOk2648ymV0ySMDomCHplojij8WBf6Aj/sPiQLcxk4YfIUAEU1nSYGedovFHXKX8GFgK2bOyQQsGWhD5YWmf5/nsj/uUPXtsn/s89pznvq/nfF7JCec+93Vf5xso317Xda5zXbJNREQp5gw6gIiI6UjSioiiJGlFRFGStCKiKElaEVGUea1UeuxiL3jVcc3XO7Kv8TojSrSPF9jvl3QkdXzgj5Z479Njtcreff9LW2yvPZL7NaWVpLXgVcdx6lUXNV7viesearzOiBJt9feOuI69T4/xsy2vqVV27qv/84QjvmFDWklaETH8DIwzPugwpi1JK6KjjDnget3DYZKkFdFhaWlFRDGMGSvwNb4krYgOGydJKyIKYWCswKRVa3KppLWSHpY0IunStoOKiJkxjmsdw6RvS0vSXOAa4H3ALuAuSZtsP9h2cBHRHgMHChzTqtPSOgMYsb3T9n7gVmBdu2FFRNuMGat5DJM6Y1rLgMcmfN4FvH1yIUnrgfUA8195bCPBRUSLDGPDlY9qaeyFadsbbK+2vXre0iVNVRsRLenNiK93DJM6La3dwMkTPi+vrkVE0cQYR/TO9UDUSVp3AadKWkEvWZ0PfLjVqCKidb2B+FmYtGyPSroY2ALMBa63vb31yCKiVb15WrMwaQHY3gxsbjmWiJhh47OxpRURs9OsbmlFxOxjxFiBK64naUV0WLqHEVEMI/Z77qDDmLYkrYiO6k0uTfewV+nIvlY2oTjmR+2srf/cu/a0Um/EsMtAfEQUwxZjTksrIgoynpZWRJSiNxBfXgooL+KIaESpA/HlRRwRjRmzah399FuSXdJCSbdV32+VdEp1fb6kGyU9IGmHpMv63StJK6KjDs6Ir3NMZcKS7GcDq4ALJK2aVOwi4BnbK4GrgSur6+cBC22/EXgr8JcHE9rhJGlFdNi459Q6+qizJPs64MbqfCOwRpLo9VKXSJoHHAXsB3411c2StCI6qvfCdO2W1gmStk041k+o6lBLsi+bdLuXy9geBZ4FjqeXwF4AngAeBb5s++mp4s5AfERHGXGg/ms8e2yvbiGMM4Ax4CTgFcCPJH3X9s7D/SAtrYiOsmHMc2odfdRZkv3lMlVXcCmwl94qyN+2fcD2k8BPgCmTY5JWRGeJ8ZpHHy8vyS5pAb0l2TdNKrMJuLA6Pxe407bpdQnfAyBpCXAmMOU7gOkeRnSUoZHXeA63JLukK4BttjcB1wE3SRoBnqaX2KD31PEGSdsBATfYvn+q+yVpRXRYU4sAHmpJdtuXTzjfR296w+TfPX+o61NJ0oroKKMsAhgR5ehtIVZeCigv4ohoyOzdrDUiZiFDndnuQydJK6LD0tKKiGLYSksrIsrRG4jPbjwRUYysEd+6tnbNedt9Y43Xedeby/sbLLqlNxCfMa2IKEhTM+JnUpJWREdlRnxEFKfEjS2StCI6yoYD40laEVGIXvcwSSsiCpIZ8RFRjFKnPPRtG0o6WdL3JT0oabukS2YisIhom5raQmxG1WlpjQKftX2PpGOAuyX9m+0HW44tIlpWY/33odM3adl+gt6eZNh+TtIOenuYJWlFFKz39LC8NzemNaZVbVd9OrD1EN+tB9YDLGJxA6FFRJtm/eRSSUcD3wQ+Zfs3tq22vQHYAHCsfseNRRgRrZmV3UMASfPpJaxbbN/ebkgRMRNKfXrYN2lJEr09y3bYvqr9kCJipgzbk8E66rS0zgI+Cjwg6b7q2uerfc
4iolC2GJ2NScv2j6HAjm9E9DUru4cRMTvN2jGtiJi9krQiohizfp5WRMw+s3ae1mzXxiYUnxx5qPE6Af5h5etbqTe6x4bRLAIYESVJ9zAiipExrYgojpO0IqIkGYiPiGLYGdOKiKKIsTw9jIiSZEwrIopR6ruH5bUNI6IZ7o1r1Tn6kbRW0sOSRiRdeojvF0q6rfp+a7V0+8Hv3iTpp9VuXw9IWjTVvZK0IjpsHNU6piJpLnANcDawCrhA0qpJxS4CnrG9ErgauLL67TzgZuCvbJ8G/CFwYKr7JWlFdJSrgfg6Rx9nACO2d9reD9wKrJtUZh1wY3W+EVhTrYr8fuB+2/8BYHuv7bGpbpakFdFh0+geniBp24Rj/YRqlgGPTfi8q7rGocrYHgWeBY4HXgtY0hZJ90j6XL+YMxAf0WHTeHq4x/bqFkKYB/w+8DbgReB7ku62/b3D/SAtrYiO6rWiVOvoYzdw8oTPy6trhyxTjWMtBfbSa5X90PYe2y8Cm4G3THWzJK2IDhu3ah193AWcKmmFpAXA+cCmSWU2ARdW5+cCd9o2sAV4o6TFVTL7A/rsXp/uYUSH1ZnO0L8Oj0q6mF4Cmgtcb3u7pCuAbbY30duG8CZJI8DT9BIbtp+RdBW9xGdgs+1vTXW/JK2IjjJivKHXeKotBTdPunb5hPN9wHmH+e3N9KY91JKkFdFhDTS0ZlySVkRXOe8eRkRpCmxqJWlFdFhaWvGytnbNuWP33a3Ue86yt7ZSbwwvA+PjSVoRUQoDaWlFREmamKc105K0IrosSSsiylHrvcKhk6QV0WVpaUVEMQzO08OIKEt5Sav225KS5kq6V9IdbQYUETPINY8hMp1XvC8BdrQVSEQMwGxNWpKWAx8Erm03nIiYMQcnl9Y5hkjdMa2vAJ8DjjlcgWqh+/UAi1h85JFFROtKnFzat6Ul6RzgSdtTvvRme4Pt1bZXz2dhYwFGRIvGVe8YInVaWmcBH5L0x8Ai4FhJN9v+SLuhRUTbNBtbWrYvs73c9in01nW+MwkrYhaoOwg/ZIkt87QiOmv4BtnrmFbSsv0D4AetRBIRM2/IWlF1pKUV0WXjgw5g+pK0IroqiwBGRGlKfHqYpBXRZQUmrWa2l42ImCGttLQ0Zw5zjj7sGz+/vbGx5usE3EK9PjDaeJ0A5yxf3Uq9Wx6/t5V6P3DSm1upN5qR7mFElMMM3Ss6dSRpRXRZWloRUZJ0DyOiLElaEVGUJK2IKIWc7mFElCZPDyOiJGlpRURZkrQiohgZ04qI4hSYtPLCdESHabze0bceaa2khyWNSLr0EN8vlHRb9f1WSadM+v41kp6X9Nf97pWkFRFHRNJc4BrgbGAVcIGkVZOKXQQ8Y3slcDVw5aTvrwL+tc79krQiuqyZ3XjOAEZs77S9H7gVWDepzDrgxup8I7BGkgAk/SnwCLC9TshJWhFd5f+bYNrvAE6QtG3CsX5CTcuAxyZ83lVd41BlbI8CzwLHSzoa+Bvgb+uGnYH4iC6rPxC/x3Ybi7l9Ebja9vNVw6uvJK2ILmvm6eFu4OQJn5dX1w5VZpekecBSYC/wduBcSX8HHAeMS9pn+58Od7MkrYiOEvWeDNZwF3CqpBX0ktP5wIcnldkEXAj8FDiX3k71Bt71cjzSF4Hnp0pYkKQV0V0NTS61PSrpYmALMBe43vZ2SVcA22xvAq4DbpI0AjxNL7H9VpK0IrqsocmltjcDmyddu3zC+T7gvD51fLHOvZK0IrqswBnx7SStBfPRSa9qvFq9uK/xOgG876XmK32phTqB8RdfbKXetnbN+exIrak30/b3K09rpd6uybuHEVGWJK2IKIYbe3o4o5K0IrosLa2IKEnGtCKiLElaEVGMeis4DJ0krYiOEmV2D2stTSPpOEkbJT0kaYekd7QdWES0bxpL0wyNui2trwLftn2upAXA4hZjioiZMmQJqY6+SUvSUuDdwF8AVCsT7m83rIiYEQUmrTrdwxXAU8
ANku6VdK2kJZMLSVp/cFXD/aPtvGoSEQ2a3sqlQ6NO0poHvAX4mu3TgReA39htw/YG26ttr14wL73HiCI0s0b8jKqTtHYBu2xvrT5vpJfEIqJwTW0hNpP6Ji3bvwAek/S66tIa4MFWo4qIGVFi97Du08NPALdUTw53Ah9rL6SImBFD2PWro1bSsn0f0MZOHBExSLM1aUXE7FPqjPgkrYgO03h5WStJK6KrZvOYVkTMTukeRkRZkrR6PG8OB048pvF65z6/sPE6AeY8/+vmK/1l81UCre3yw+hoK9W2tWvOu+5vfmemH71pUeN1Dru0tCKiLElaEVGM7MYTESXJPK2IKI/Ly1pJWhEdlpZWRJQjk0sjojQZiI+IoiRpRUQ5TAbiI6IsGYiPiLIkaUVEKTK5NCLKYmcRwIgoTHk5q9a+hxExSzW1hZiktZIeljQi6Tc2c5a0UNJt1fdbJZ1SXX+fpLslPVD98z397pWkFdFVBsZd75iCpLnANcDZwCrgAkmrJhW7CHjG9krgauDK6voe4E9svxG4ELipX9hJWhFd5prH1M4ARmzvtL0fuBVYN6nMOuDG6nwjsEaSbN9r+/Hq+nbgKElTrvaZpBXRYdPoHp4gaduEY/2EapYBj034vKu6xqHK2B4FngWOn1Tmz4B7bE+5PG8G4iM6bBpPD/fYbm3DZkmn0esyvr9f2bS0Irqqbtewf17bDZw84fPy6tohy0iaBywF9laflwP/Avy57f/qd7NWWlrj8+fw4qub34Ri0d52cuyC8ebfGp2z/0DjdQLohRdaqbe0J99tbEKx5IevbLxOgBfe/VQr9R6p3uTSRv7L3wWcKmkFveR0PvDhSWU20Rto/ylwLnCnbUs6DvgWcKntn9S5WVpaEV02XvOYQjVGdTGwBdgBfMP2dklXSPpQVew64HhJI8BngIPTIi4GVgKXS7qvOk6c6n4Z04rosIZaWtjeDGyedO3yCef7gPMO8bsvAV+azr2StCK6KiuXRkRZ8u5hRJQmiwBGRDGyWWtEFKfAllatKQ+SPi1pu6SfS/q6pOYnyUTEzGtmcumM6pu0JC0DPgmstv0GYC69yWMRUTiNj9c6hknd7uE8em9fHwAWA4/3KR8Rw870nTg6jPq2tGzvBr4MPAo8ATxr+zuTy0laf/AN8AMvtfOqSUQ0Rxi53jFM6nQPX0FvLZwVwEnAEkkfmVzO9gbbq22vnr9wSfORRkTz7HrHEKkzEP9e4BHbT9k+ANwOvLPdsCJiRhSYtOqMaT0KnClpMfBrYA2wrdWoIqJ9hY5p9U1atrdK2gjcA4wC9wIb2g4sIto3bE8G66j19ND2F4AvtBxLRMyo4ev61ZEZ8RFdZZK0IqIw5fUOk7QiumzY5mDVkaQV0WVJWhFRDBvGyusftrQbD7x4YvN7ZljzG68TQGNHNV7ngv2jjdcJgNROvdHarjlP3/Haxuscu+THzVSUllZEFCVJKyKKYSBrxEdEOQzOmFZElMJkID4iCpMxrYgoSpJWRJQjL0xHREkMzNalaSJilkpLKyLKkdd4IqIkBmeeVkQUJTPiI6IoGdOKiGLYeXoYEYVJSysiymE8NjboIKYtSSuiq7I0TUQUp8ApD82viRwRRTDgcdc6+pG0VtLDkkYkXXqI7xdKuq36fqukUyZ8d1l1/WFJH+h3ryStiK5ytQhgnWMKkuYC1wBnA6uACyStmlTsIuAZ2yuBq4Erq9+uAs4HTgPWAv9c1XdYSVoRHeaxsVpHH2cAI7Z32t4P3Aqsm1RmHXBjdb4RWCNJ1fVbbb9k+xFgpKrvsFoZ0/r1k7v23P+Pn/mfGkVPAPa0EUNLSoq3pFihrHinF+sHW4nhd4+0gud4Zst3vfGEmsUXSdo24fMG2xuq82XAYxO+2wW8fdLvXy5je1TSs8Dx1fV/n/TbZVMF0krSsv3KOuUkbbO9uo0Y2lBSvCXFCmXFW1KsU7G9dtAx/DbSPYyII7UbOHnC5+XVtUOWkT
QPWArsrfnb/ydJKyKO1F3AqZJWSFpAb2B906Qym4ALq/NzgTttu7p+fvV0cQVwKvCzqW426HlaG/oXGSolxVtSrFBWvCXF2rpqjOpiYAswF7je9nZJVwDbbG8CrgNukjQCPE0vsVGV+wbwIDAKfNz2lCP/coHvHkVEd6V7GBFFSdKKiKIMLGn1m/Y/LCSdLOn7kh6UtF3SJYOOqQ5JcyXdK+mOQccyFUnHSdoo6SFJOyS9Y9AxTUXSp6s/Bz+X9HVJiwYdU9cMJGnVnPY/LEaBz9peBZwJfHyIY53oEmDHoIOo4avAt22/Hvg9hjhmScuATwKrbb+B3qDz+YONqnsG1dKqM+1/KNh+wvY91flz9P6nmnLG7qBJWk5vHva1g45lKpKWAu+m92QJ2/tt/3KwUfU1Dziqmmu0GHh8wPF0zqCS1qGm/Q91IgCo3kw/Hdg62Ej6+grwOWDY1x1ZATwF3FB1Za+VtGTQQR2O7d3Al4FHgSeAZ21/Z7BRdU8G4muSdDTwTeBTtn816HgOR9I5wJO27x50LDXMA94CfM326cALwDCPb76CXo9gBXASsETSRwYbVfcMKmlNe+r+IEmaTy9h3WL79kHH08dZwIck/Te9bvd7JN082JAOaxewy/bBlutGeklsWL0XeMT2U7YPALcD7xxwTJ0zqKRVZ9r/UKiWz7gO2GH7qkHH04/ty2wvt30KvX+vd9oeytaA7V8Aj0l6XXVpDb2Z0cPqUeBMSYurPxdrGOIHB7PVQF7jOdy0/0HEUsNZwEeBByTdV137vO3NA4xpNvkEcEv1l9dO4GMDjuewbG+VtBG4h95T5XvJKz0zLq/xRERRMhAfEUVJ0oqIoiRpRURRkrQioihJWhFRlCStiChKklZEFOV/AW4hr2B1IuFbAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
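    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, no coupling of $p$ and $q$ can place more than $\\sum_i \\min(p_i, q_i) = 1 - \\mathrm{TV}(p, q)$ on the diagonal, and a maximal coupling attains this bound. The sketch below compares that analytic value with the diagonal mass of both estimated couplings, again assuming `joint_from_samples` returns normalized joint probabilities."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "p_probs = np.exp(np.asarray(p_logits))\n",
        "q_probs = np.exp(np.asarray(q_logits))\n",
        "# Largest achievable P(x_p == x_q) over all couplings of p and q.\n",
        "analytic_max = np.sum(np.minimum(p_probs, q_probs))\n",
        "print(\"analytic maximum:\", analytic_max)\n",
        "# Diagonal mass of each sampled coupling (assumed to be normalized).\n",
        "print(\"maximal coupling (estimate):\", np.trace(np.asarray(maximal_coupling_p_q)))\n",
        "print(\"Gumbel-max coupling (estimate):\", np.trace(np.asarray(gm_coupling_p_q)))"
      ],
      "execution_count": null,
      "outputs": []
    },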
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "biN_LgVmNrdM"
      },
      "source": [
        "Plotting the difference between the two reveals that the Gumbel-max sampler assigns less mass to the diagonal and more mass to the off-diagonal elements."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "YK_JgAsUMxyo",
        "outputId": "40888e63-475f-4545-9888-fcab65e91f8d"
      },
      "source": [
        "difference = gm_coupling_p_q - maximal_coupling_p_q\n",
        "plt.imshow(difference, vmin=-0.016, vmax=0.016, cmap=\"RdBu\")\n",
        "plt.colorbar()"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b86be8450>"
            ]
          },
          "metadata": {},
          "execution_count": 15
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAD4CAYAAABxC1oQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYd0lEQVR4nO3df7BdZX3v8ffnnEOQgMGQtIhJKsyQUlM6omaAlk6v12CIrWPoXNDYqmkvNL1XqdbrvQzUqzhwuQNeW2pvucykEButFpxox+hQkB/SjjOFchQUARkygHDS8CMJBlBCOOd87h97hW4O+5y9Dnvvs885z+c1syZrPftZz/5uCF+eZz1rPUu2iYgowUC/A4iImClJeBFRjCS8iChGEl5EFCMJLyKKMdSLRpcsWeoVb/ylrreb7BzR8Oijj7J792510sbAouVmdH+tun5+z42213XyfbNBTxLeijf+Erf+03e73u5hQx39+42YN0477bTOGxndz9AJ76lV9cW7v7C08y/sv54kvIiYAyQ0MNjvKGZUEl5EscTA0IJ+BzGjclksolRVD6/O1r4prZP0gKQdki5o8fmhkq6rPr9D0rFV+RJJ35H0nKS/nnDObVWbd1fbL3b6k9PDiyiUAA12PqSVNAhcCbwTGAHulLTd9n1N1c4BnrZ9vKQNwOXA+4D9wKeAE6ttot+3PdxxkJX08CJKJTEwMFhra+NkYIfth2wfAK4F1k+osx7YWu1vA9ZIku2f2f4ujcTXc0l4EQWbxpB2qaThpm1TUzPLgMeajkeqMlrVsT0K7AOW1AjxC9Vw9lOSOr5No9aQVtI64PPAIHC17cs6/eKI6LPpzdLutr26l+G08Pu2d0p6LfA14IPAFztpsG0Pr2l8/i5gFfB+Sas6+dKI6D8hBoYOqbW1sRNY0XS8vCprWUfSEHAksGeqRm3vrP58FvgKjaFzR+oMaeuMzyNiruneLO2dwEpJx0laAGwAtk+osx3YWO2fBdzqKRbjlDQkaWm1fwjwbuBHr+JXvkydIW2r8fkpLQLcBGwCWL5ixcSPI2IW6saNx7ZHJZ0H3EjjstcW2/dKuhgYtr0duAb4kqQdwF4aSbERg/QIsAhYIOlMYC3wE+DGKtkNAjcDf9NprF27LcX2ZmAzwElvfWuWUY6Y7aSu3JYCYPt64PoJZZ9u2t8PnD3JucdO0uzbuhJckzoJr874PCLmGNGdHt5cUifhvTQ+p5HoNgC/19OoIqL3NMBgYY+WtU14k43Pex5ZRPSW0sNrqdX4PCLmNpHVUiKiIEl4EVGGrIcXEeVIwouIQkhi4JDM0kZECTKknd0GXniuJ+2OH3pET9qNmO2S8CKiGAMDZb0JMAkvolCSUBJeRJRicLCsRc+T8CJKJdLDi4gyNFZLScKLiCKIgc7fizOnJOFFlCpD2ogoSRJeRBRBgsGhJLyIKEQX3m09pyThRRRKUp60iIhy5BpeRBQjCS8iyiByH15ElEGIgaGynqUt69dGxL9TY3moOlvbpqR1kh6QtEPSBS0+P1TSddXnd0g6tipfIuk7kp6T9NcTznmbpHuqc/5KXZhSTsKLKJikWlubNgaBK4F3AauA90taNaHaOcDTto8HrgAur8r3A58C/nuLpq8C/ghYWW3rXuXPfEkSXkShGosH1NvaOBnYYfsh2weAa4H1E+qsB7ZW+9uANZJk+2e2v0sj8f17bNIxwCLbt9s28EXgzI5+MLmGF1EuTWvF46WShpuON9veXO0vAx5r+mwEOGXC+S/VsT0qaR+wBNg9yfctq9ppbnNZ3WAnk4QXUSwxUH8B0N22V/cympmQhBdRKE2vhzeVncCKpuPlVVmrOiOShoAjgT1t2lzeps1p60nCGxuHfS+Mdb3d8QULu94mwKIf3ND1Nsfe3PH11Yie69KNx3cCKyUdRyMpbQB+b0Kd7cBG4F+As4Bbq2
tzLdneJekZSacCdwAfAv5vp4GmhxdRKAkGu5Dwqmty5wE3AoPAFtv3SroYGLa9HbgG+JKkHcBeGkmxikOPAIuABZLOBNbavg/4MPC3wGHAP1ZbR5LwIgrWjYQHYPt64PoJZZ9u2t8PnD3JucdOUj4MnNiVACtJeBGFEupawpsrkvAiCiXBgsIeLUvCiyiUBEPp4UVECUT3ruHNFUl4EaVSruG9gqQVNJ5jOxowjUdKPt/rwCKitxo9vFzDm2gU+ITt70t6LfA9STdV98lExByWHt4EtncBu6r9ZyXdT+Mh3iS8iDlsQMos7VSqRfveQuNRj4mfbQI2Abxh+YqJH0fELDRY2BLvtdO7pCOArwF/avuZiZ/b3mx7te3VRy1Z2s0YI6IHDj5aVmebL2r18CQdQiPZfdn213sbUkTMlPmUzOqoM0srGg/+3m/7L3ofUkTMhNx43NppwAeBeyTdXZX9WfWwcETMUSKTFq9QrTdf1v8GIgrQreWh5pI8aRFRqDxaFhHlSA8vIkqR9fAioihJeF3w4tg4I8+80PV2jz5iQdfbBBh709qutzl+RasXqXfudR//XE/ajfIMZAHQiChGruFFRCmEinuWNgkvomADSXgRUQIBg2XluyS8iGIJBnINLyJKIOCQwpZ4L+vXRsRLDg5p62xt25LWSXpA0g5JF7T4/FBJ11Wf31EtJnzwswur8gckndFU/oikeyTdLWm4G785PbyIUkldGdJKGgSuBN4JjAB3Sto+4b035wBP2z5e0gbgcuB9klYBG4BfBd4A3Czpl22PVef9R9u7Ow6ykh5eRKFEY5a2ztbGycAO2w/ZPgBcC6yfUGc9sLXa3wasqdbaXA9ca/sF2w8DO6r2eiIJL6Jg0xjSLpU03LRtampmGfBY0/FIVUarOrZHgX3AkjbnGvi2pO9N+L5XLUPaiEJJcMhg7T7PbturexlPC79pe6ekXwRukvRj2//cSYPp4UUUqotD2p1A86sKl1dlLetIGgKOBPZMda7tg38+CfwDXRjqJuFFFKxLs7R3AislHSdpAY1JiO0T6mwHNlb7ZwG32nZVvqGaxT0OWAn8q6TDJb0WQNLhwFrgR53+3gxpIwolavXe2rI9Kuk84EZgENhi+15JFwPDtrfTeBHYlyTtAPbSSIpU9b4K3AeMAh+xPSbpaOAfGvMaDAFfsX1Dp7Em4UWUqourpVQv9bp+Qtmnm/b3A2dPcu6lwKUTyh4C3tyV4Jok4UUUqnENr99RzKwkvIhClfhoWRJeRKkE9e9KmR+S8CIKdfC2lJIk4UUUKyseR0Qh0sPrkudHx7nniee63u6+/a/pepsAxy4+rOttLv6vl3e9TYD/c9Sv9aTd/7H3np60G7NX49GyJLyIKERhHbwkvIiSDVBWxkvCiyiUSA8vIgqSJy0iogxKDy8iCqHchxcRJSltSFv7STpJg5LukvStXgYUETNHNbf5Yjo9vI8B9wOLehRLRMygEp+0qNXDk7Qc+B3g6t6GExEzSaq3zRd1h7R/CZwPjE9WQdKmg69we+7pPV0JLiJ6a6DmNl+0/S2S3g08aft7U9Wzvdn2aturj1i8pGsBRkRvqFrivc42X9S5hnca8B5Jvw28Blgk6e9sf6C3oUVEr82n4WodbXt4ti+0vdz2sTTeNHRrkl3E3CfKG9LmPryIgqmwLt60Ep7t24DbehJJRMwslXfjcXp4EYUSUNj6n0l4ESUrbUg7n65HRsQ0HHwRd52tbVvSOkkPSNoh6YIWnx8q6brq8zskHdv02YVV+QOSzqjb5quRhBdRsG48SytpELgSeBewCni/pFUTqp0DPG37eOAK4PLq3FU07v74VWAd8P+q5/brtDltSXgRxRIDqre1cTKww/ZDtg8A1wLrJ9RZD2yt9rcBa9QYT68HrrX9gu2HgR1Ve3XanLaeXMM7ZEAc89pDu95uL9oEerIm2Ivj7nqbAB9+4gc9afcTC9/Uk3b//Of396Td6ILpPS
e7VNJw0/Fm25ur/WXAY02fjQCnTDj/pTq2RyXtA5ZU5bdPOHdZtd+uzWnLpEVEoWSj8bG61XfbXt3LeGZCEl5EweRJ1wOZjp3Aiqbj5VVZqzojkoaAI4E9bc5t1+a05RpeRLEMHq+3Te1OYKWk4yQtoDEJsX1Cne3Axmr/LBqPqLoq31DN4h4HrAT+tWab05YeXkTJ3Pm15uqa3HnAjcAgsMX2vZIuBoZtbweuAb4kaQewl0YCo6r3VeA+YBT4iO0xgFZtdhprEl5Eqew6vbeaTfl64PoJZZ9u2t8PnD3JuZcCl9Zps1NJeBEF69I1vDkjCS+iWIbx0X4HMaOS8CJKZbo2pJ0rkvAiimUYT8KLiELkGl5ElCMJLyKKYEP9R8vmhSS8iIJlSBsRhejejcdzRRJeRMmS8CKiCF18tGyuSMKLKJTINbyIKIZhLLO0EVGCPFoWESXJkDYiCpFJi6444tBB/sMbj+x6u2NdWJ21Zbs9+Hfeqxe6vzDWm38GF/2048VkWxr96mU9aXfovV15L3Mk4UVEEfJoWUSUw3j0xX4HMaOS8CJKZdLDi4gyGOPchxcRRTBZ8TgiSpFJi4gohcubtBioU0nS6yRtk/RjSfdL+vVeBxYRvWY8PlZrmy9qJTzg88ANtn8FeDNwf+9CiogZcXCWts7WAUlHSbpJ0oPVn4snqbexqvOgpI1N5W+TdI+kHZL+Smrc1i/pM5J2Srq72n67XSxtE56kI4HfAq4BsH3A9k/r/tiImK2q1zTW2TpzAXCL7ZXALdXxy0g6CrgIOAU4GbioKTFeBfwRsLLa1jWdeoXtk6rt+naB1OnhHQc8BXxB0l2SrpZ0eIuAN0kaljS8Z/fuGs1GRF8ZPDZWa+vQemBrtb8VOLNFnTOAm2zvtf00cBOwTtIxwCLbt9s28MVJzq+lTsIbAt4KXGX7LcDPaJGhbW+2vdr26iVLl77aeCJixng6Q9qlBzs01bZpGl90tO1d1f7jwNEt6iwDHms6HqnKllX7E8sPOk/SDyVtmWyo3KzOLO0IMGL7jup4Gy0SXkTMMdObpd1te/VkH0q6GXh9i48++fKvtCV1awWMq4BLaFyNvAT4c+A/T3VC24Rn+3FJj0k6wfYDwBrgvi4EGxF91b378GyfPtlnkp6QdIztXdUQ9ckW1XYCb286Xg7cVpUvn1C+s/rOJ5q+42+Ab7WLs+4s7Z8AX5b0Q+Ak4H/XPC8iZqsZmqUFtgMHZ103At9oUedGYK2kxdXQdC1wYzUUfkbSqdXs7IcOnl8lz4N+F/hRu0Bq3Xhs+25g0u5sRMw9xnhmHi27DPiqpHOAnwDvBZC0Gvgvts+1vVfSJcCd1TkX295b7X8Y+FvgMOAfqw3gs5JOopG6HwH+uF0gedIiolQztFqK7T00LoVNLB8Gzm063gJsmaTeiS3KPzjdWJLwIkpl4xcP9DuKGZWEF1EsZ7WUiCjIPHpOto4kvIhS2fNqYYA6epLwBkdfYNFPH+56u+ML295I/eraPfQVT8p17Nnx3vy/ZN8LvfkLKnrzmrWB3z2/J+2+7r5bu97m6Kp3dL3N2W6GZmlnjfTwIkpl4168o3QWS8KLKJRtxl8c7XcYMyoJL6JUJj28iChHEl5EFME243lNY0SUIrO0EVGGzNJGRCkySxsRRRlPDy8iipDbUiKiGLmGFxGlMJmljYhS2IwfyKRFRJTAMJ4eXkSUwOQaXkSUwuA8WhYRZZix1zTOGkl4EaXKfXgRUQrbjGWWNiLKkCFtV/i5p9n/T9u63u6hb1rd9TYBOOE3u97koHrzUpznX+zNX9Bdz77Qk3YHBxb2pN3x49/e9TaP7HqLDe5Rux2boSGtpKOA64BjgUeA99p+ukW9jcD/rA7/l+2tVfmlwIeAxbaPaKp/KPBF4G3AHuB9th+ZKpaBDn9LRMxVBo+51tahC4BbbK8EbqmOX6ZKihcBpwAnAxdJOviawm9WZROdAzxt+3jgCuDydoEk4U
UUypjxsfFaW4fWA1ur/a3AmS3qnAHcZHtv1fu7CVgHYPt227vatLsNWCNNPbTKNbyIUhk8Xrv3tlTScNPxZtuba557dFPCehw4ukWdZcBjTccjVdlUXjrH9qikfcASYPdkJyThRRTKhrEDtW883m170ovokm4GXt/io0++/DttSX27rJmEF1Eqd+X6XNWUT5/sM0lPSDrG9i5JxwBPtqi2E3h70/Fy4LY2X7sTWAGMSBqiMe+0Z6oTcg0vomDjY661dWg7sLHa3wh8o0WdG4G1khZXkxVrq7K67Z4F3Gp7ymCT8CJKVd2WUmfr0GXAOyU9CJxeHSNptaSrAWzvBS4B7qy2i6syJH1W0giwUNKIpM9U7V4DLJG0A/hvtJj9najWkFbSx4FzadxSdA/wh7b31/yxETELGRivP2nx6r/H3gOsaVE+TCOvHDzeAmxpUe984PwW5fuBs6cTS9senqRlwEeB1bZPBAaBDdP5koiYhWzGDozV2uaLupMWQ8Bhkl4EFgL/1ruQImImuLrxuCRte3i2dwKfAx4FdgH7bH97Yj1JmyQNSxre/czPux9pRHTXzD1pMWvUGdIupnFH83HAG4DDJX1gYj3bm22vtr166aLePD8ZEd00Y09azBp1ZmlPBx62/ZTtF4GvA7/R27AioueqJy3qbPNFnWt4jwKnSloIPE9jtmV46lMiYrYzdOMeuzmlbcKzfYekbcD3gVHgLqDuM3QRMVvZjM+jGdg6as3S2r6IxtItETFP2OnhRURBsuJxRJTBXXlOdk5JwosoVYE3HifhRRTK5DWNEVEKm7EDSXgde37Pc9z/lX/uersnnN2bd2guXLq8620u+IVf7nqbAPc88WxP2v15j96GtviwQ3rS7iED3X8rXC/aBHh+tPv/bLsxErVhfOrl4+ad9PAiCjaWhBcRJTDd6SnOJUl4EQVLDy8iijBuODCPFgaoIwkvomAZ0kZEEYwzpI2IMmTSIiKKkoQXEUWwM0sbEYUwmaWNiEKUeA2vzkt8ImKeGrNrbZ2QdJSkmyQ9WP25eJJ6G6s6D0ra2FR+qaTHJD03of4fSHpK0t3Vdm67WJLwIgrVuIZXb+vQBcAttlcCt1THLyPpKBqvkTgFOBm4qCkxfrMqa+U62ydV29XtAknCiyjYTPTwaLzXemu1vxU4s0WdM4CbbO+1/TRwE7AOwPbttnd1GgQk4UUUy8B4zQ1YKmm4ads0ja86uilhPQ4c3aLOMuCxpuORqqyd/yTph5K2SVrRrnImLSIKZTydWdrdtldP9qGkm4HXt/joky/7TtuSujVV8k3g722/IOmPafQe3zHVCUl4EYVqzNJ2J/fYPn2yzyQ9IekY27skHQM82aLaTuDtTcfLgdvafOeepsOrgc+2izND2ohSzdykxXbg4KzrRuAbLercCKyVtLiarFhblU2qSp4HvQe4v10gSXgRhTrYw5uBSYvLgHdKehA4vTpG0mpJVwPY3gtcAtxZbRdXZUj6rKQRYKGkEUmfqdr9qKR7Jf0A+CjwB+0CyZA2omAzceNxNfRc06J8GDi36XgLsKVFvfOB81uUXwhcOJ1YkvAiCjVOeY+WyT14eFjSU8BPalRdCuzuegC9M5finUuxwtyKdzbE+kbbv9BJA5JuoPFb6thte10n3zcb9CTh1f5yaXiqqe7ZZi7FO5dihbkV71yKNV4ukxYRUYwkvIgoRr8T3uY+f/90zaV451KsMLfinUuxRpO+XsOLiJhJ/e7hRUTMmCS8iChG3xKepHWSHpC0Q9IrFgScLSStkPQdSfdVj7F8rN8x1SFpUNJdkr7V71imIul11dI+P5Z0v6Rf73dMU5H08ervwY8k/b2k1/Q7pqivLwlP0iBwJfAuYBXwfkmr+hFLDaPAJ2yvAk4FPjKLY232MWo8TD0LfB64wfavAG9mFscsaRmNZzZX2z4RGAQ29DeqmI5+9fBOBnbYfsj2AeBaGquizjq2d9n+frX/LI3/IOssTN
g3kpYDv0NjyZxZS9KRwG8B1wDYPmD7p/2Nqq0h4DBJQ8BC4N/6HE9MQ78S3qtd3bSvJB0LvAW4o7+RtPWXNB62Hu93IG0cBzwFfKEafl8t6fB+BzUZ2zuBzwGPAruAfba/3d+oYjoyaVGTpCOArwF/avuZfsczGUnvBp60/b1+x1LDEPBW4CrbbwF+RosXvMwW1Tpt62kk6jcAh0v6QH+jiunoV8LbCTSvP7+8KpuVJB1CI9l92fbX+x1PG6cB75H0CI1LBe+Q9Hf9DWlSI8CI7YM95m00EuBsdTrwsO2nbL8IfB34jT7HFNPQr4R3J7BS0nGSFtC48Lu9T7FMSZJoXGO63/Zf9DuedmxfaHu57WNp/HO91fas7IXYfhx4TNIJVdEa4L4+htTOo8CpkhZWfy/WMIsnWeKV+rIenu1RSefRWMJ5ENhi+95+xFLDacAHgXsk3V2V/Znt6/sY03zyJ8CXq//xPQT8YZ/jmZTtOyRtA75PY/b+LvKY2ZySR8siohiZtIiIYiThRUQxkvAiohhJeBFRjCS8iChGEl5EFCMJLyKK8f8BlnkG8ZaFKAAAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "acU5L0pDOUW9"
      },
      "source": [
        "### Using the Gumbel-max SCM to sample counterfactuals"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dR1AyEY8OlVz"
      },
      "source": [
        "We can use the top-down sampling algorithm of [Maddison et al. (2014)](https://arxiv.org/abs/1411.0030) to answer counterfactual queries: given that we observed $x^{(obs)}$ under `p_logits`, what would we have observed under `q_logits`?\n",
        "\n",
        "The key insight is that, for a set of independent shifted Gumbels, the maximum value and the argmax are independent (as explained [here](https://cmaddis.github.io/gumbel-machinery)). Since $x^{(obs)}$ is the argmax, we can sample the exogenous noise by first sampling the max, then sampling the remaining Gumbels conditioned on lying below it."
      ]
    },
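    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the top-down idea concrete, here is a minimal, self-contained NumPy sketch of sampling the exogenous Gumbel noise conditioned on an observed argmax. It assumes the standard truncated-Gumbel construction: the max is drawn from a Gumbel located at the log-sum-exp of the logits, and every other coordinate is a Gumbel truncated to lie below that max. The helper names (`sample_truncated_gumbel`, `posterior_gumbels`) and the example distributions are hypothetical. This is not necessarily how `coupling_util.counterfactual_gumbels` is implemented."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def sample_truncated_gumbel(location, bound, rng):\n",
        "  # Inverse-CDF sample of Gumbel(location) conditioned on being <= bound.\n",
        "  u = 1.0 - rng.uniform(size=np.shape(location))  # u lies in (0, 1]\n",
        "  return location - np.log(np.exp(location - bound) - np.log(u))\n",
        "\n",
        "def posterior_gumbels(logits, x_obs, rng):\n",
        "  # Sample Gumbel noise consistent with argmax(logits + noise) == x_obs.\n",
        "  logits = np.asarray(logits, dtype=float)\n",
        "  # The max of the shifted Gumbels is Gumbel(logsumexp(logits)) and is\n",
        "  # independent of which index attains it, so sample it first.\n",
        "  top = np.logaddexp.reduce(logits) + rng.gumbel()\n",
        "  # Fill in the remaining shifted Gumbels, truncated to lie below the max.\n",
        "  shifted = sample_truncated_gumbel(logits, top, rng)\n",
        "  shifted[x_obs] = top\n",
        "  return shifted - logits\n",
        "\n",
        "# Hypothetical example distributions, not the notebook's p_logits/q_logits.\n",
        "p = np.log(np.array([0.2, 0.5, 0.3]))\n",
        "q = np.log(np.array([0.4, 0.1, 0.5]))\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Monte Carlo estimate of the counterfactual distribution of y under q,\n",
        "# given that x_obs = 1 was observed under p.\n",
        "counts = np.zeros(3)\n",
        "for _ in range(5000):\n",
        "  noise = posterior_gumbels(p, 1, rng)\n",
        "  counts[np.argmax(noise + q)] += 1\n",
        "print(counts / counts.sum())"
      ],
      "execution_count": null,
      "outputs": []
    },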
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "id": "OOAU-pcPNlvX",
        "outputId": "825f1ed4-7fa6-41bd-e6b6-17a65a22ec10"
      },
      "source": [
        "x_obs = 7  # for example\n",
        "rng = jax.random.PRNGKey(1234)\n",
        "\n",
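        "# Draw one counterfactual sample: sample Gumbel noise consistent with observing\n",
        "# x_obs under p_logits, then take the argmax under q_logits (as a one-hot vector).\n",
        "def sample_one(key):\n",
        "  gumbels = coupling_util.counterfactual_gumbels(p_logits, x_obs, key)\n",
        "  y_for_q = jnp.argmax(gumbels + q_logits)\n",
        "  return jax.nn.one_hot(y_for_q, q_logits.shape[-1])\n",
        "\n",
        "# Use jax.vmap to draw many samples and average them into a distribution.\n",
        "counterfactual_y = jnp.mean(jax.vmap(sample_one)(jax.random.split(rng, 1000)), axis=0)\n",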
        "plt.imshow(counterfactual_y[None, :], vmin=0)"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b86b84e90>"
            ]
          },
          "metadata": {},
          "execution_count": 16
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAABECAYAAACCuY6+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAGzklEQVR4nO3dX4icVx3G8e/T3aSaVNJoBdOkmIhaDRZNu2g1UKRJ0aKkFyq0oLRiWS+sqUWw/gEvvIoi/rkQIaRK0VILsWiUWLWkxQshdtvGP01MG6M2idEmaVqTYpusebyYN+4wzmZ3887OGed9PrDs++fse34cZp+dfWfmHNkmIiKG3wWlC4iIiP5I4EdENEQCPyKiIRL4ERENkcCPiGiIBH5EREPUCnxJr5T0S0lPVd+XTtPu35J2VV/b6vQZERHnR3Xehy/pK8CztjdJ+iyw1PadXdqdtH1RjTojIqKmuoG/F3i37cOSlgEP2768S7sEfkREYXUD/znbF1fbAo6f3e9oNwnsAiaBTbZ/NM31xoFxgAtGFl61aPGrz7u2XtCZwfgUsibPlC5hcExOlq4ABuTT6R6Ax+cbr3ihdAkAPPm7RaVLGBgnOH7UdtfwnDHwJT0IvKbLqS8Ad7cHvKTjtv/nPr6k5bYPSXodsANYZ/tP5+r3FUtW+Mq1G89Z23xbcOJ00f7PGj16snQJg+PY8dIVwKnBeFyc+deLpUvggb/+pnQJALzn0reVLmFgPOitj9oe63ZudKYftr1+unOS/iFpWdstnWemucah6vt+SQ8Da4BzBn5ERPRW3bdlbgNurrZvBn7c2UDSUkkXVtuXAGuB3TX7jYiIOaob+JuA6yQ9Bayv9pE0JmlL1ebNwISk3wIP0bqHn8CPiOizGW/pnIvtY8C6LscngFur7V8DV9TpJyIi6ssnbSMiGiKBHxHREAn8iIiGSOBHRDREAj8ioiES+BERDZHAj4hoiAR+RERD9CTwJb1X0l5J+6p58TvPXyjpvur8Tkkre9FvRETMXu3AlzQCfAu4HlgN3CRpdUezj9GaOvn1wNeBL9ftNyIi5qYXz/DfDuyzvd/2KeAHwA0dbW4A7q62twLrqvnzIyKiT3oR+MuBA237B6tjXdvYngSeB17VeSFJ45ImJE2cPjUYCytERAyLgXrR1vZm22O2xxYsXFy6nIiIodKLwD8EXNa2v6I61rWNpFFgCXCsB31HRMQs9SLwHwHeIGmVpIXAjbQWRmnXvlDKB4EdrrOYbkREzFmt+fChdU9e0m3Az4ER4Du2n5D0JWDC9jbgLuB7kvYBz9L6oxAREX1UO/ABbG8Htncc+2Lb9ovAh3rRV0REnJ+BetE2IiLmTwI/IqIhEvgREQ2RwI+IaIgEfkREQyTwIyIaIoEfEdEQCfyIiIbo1wIot0g6ImlX9XVrL/qNiIjZq/1J27YFUK6jNTXyI5K22d7d0fQ+27fV7S8iIs5PvxZAiYiIwnoxl063BVDe0aXdByRdAzwJ3GH7QGcDSePAeLV78lc/u3NvzdouAY7WvMawyFhMyVhMqTUWI8t6WEkt+3pxkWF5XLx2uhM9mTxtFn4C3Gv7JUkfp7Xc4bWdjWxvBjb3qlNJE7bHenW9/2cZiykZiykZiylNGIu+LIBi+5jtl6rdLcBVPeg3IiLmoC8LoEhq/8dvA7CnB/1GRMQc9GsBlI2SNgCTtBZAuaVuv7PUs9tDQyBjMSVjMSVjMWXox0JZaTAiohnySduIiIZI4EdENMTQBv5M0z00haTLJD0kabekJyTdXrqmkiSNSHpc0k9L11KapIslbZX0R0l7JL2zdE2lSLqj+v34g6R7Jb2sdE3zYSgDv226h+uB1cBNklaXraqYSeDTtlcDVwOfaPBYANxO3iV21jeBB2y/CXgrDR0XScuBjcCY7bfQevPJjWWrmh9DGfhkuof/sn3Y9mPV9g
lav9TLy1ZVhqQVwPtofRak0SQtAa4B7gKwfcr2c2WrKmoUeLmkUWAR8LfC9cyLYQ38btM9NDLk2klaCawBdpatpJhvAJ8BzpQuZACsAo4A361ucW2RtLh0USXYPgR8FXgaOAw8b/sXZauaH8Ma+NFB0kXAD4FP2f5n6Xr6TdL7gWdsP1q6lgExClwJfNv2GuAFoJGvdUlaSusOwCrgUmCxpA+XrWp+DGvgzzjdQ5NIWkAr7O+xfX/pegpZC2yQ9Bdat/iulfT9siUVdRA4aPvsf3tbaf0BaKL1wJ9tH7F9GrgfeFfhmubFsAb+jNM9NIUk0bpPu8f210rXU4rtz9leYXslrcfDDttD+SxuNmz/HTgg6fLq0Dqgcw2LpngauFrSour3ZR1D+gJ2v2bL7KvppnsoXFYpa4GPAL+XtKs69nnb2wvWFIPhk8A91ZOi/cBHC9dThO2dkrYCj9F6V9vjDOk0C5laISKiIYb1lk5ERHRI4EdENEQCPyKiIRL4ERENkcCPiGiIBH5EREMk8CMiGuI/zSUr12JWF7wAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FkB3rq37SlYM"
      },
      "source": [
        "This is equivalent to sampling from a single row (row 7, renormalized) of the Gumbel-max coupling matrix from the previous section.\n",
        "\n",
        "The key property that makes this useful for counterfactual inference is that it works for any `x_obs`, even one that we did not sample using our mechanism, as long as `x_obs` can be viewed as a sample from the distribution given by `p_logits`. It can therefore be used to infer counterfactual distributions for data collected offline from real-world interactions, as described by Oberst and Sontag (2019)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mKkBEHX9dgCe"
      },
      "source": [
        "## Inverse-CDF couplings and monotonicity"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f4uujY72dpWu"
      },
      "source": [
        "Another interesting class of couplings comes from the \"inverse CDF\" causal mechanism. If we fix an order on the outcomes, we can use it to construct the cumulative distribution function (CDF) for any logit vector $l$. Evaluating the inverse of this CDF at a number $u$ drawn uniformly from $[0, 1)$ produces a sample from the distribution given by $l$ (inverse transform sampling). Using the same $u$ for both `p_logits` and `q_logits` yields the coupling below."
      ]
    },
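    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is a minimal sketch of this mechanism, assuming the outcomes are ordered by index. Because the two distributions share a single uniform draw $u$, cell $(i, j)$ of the coupling receives the length of the overlap between the CDF interval of outcome $i$ under the first distribution and that of outcome $j$ under the second. The function names below are hypothetical, and we do not claim this matches the internals of `coupling_util.inverse_cdf_coupling`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax(logits):\n",
        "  z = np.exp(logits - np.max(logits))\n",
        "  return z / z.sum()\n",
        "\n",
        "def inverse_cdf_sample(logits, u):\n",
        "  # Smallest index whose CDF value exceeds u (inverse transform sampling).\n",
        "  cdf = np.cumsum(softmax(logits))\n",
        "  return min(int(np.searchsorted(cdf, u, side='right')), len(cdf) - 1)\n",
        "\n",
        "def inverse_cdf_coupling_np(logits_1, logits_2):\n",
        "  # Joint mass of (x, y) when both are sampled from the same u ~ U[0, 1).\n",
        "  cdf_1 = np.concatenate([[0.0], np.cumsum(softmax(logits_1))])\n",
        "  cdf_2 = np.concatenate([[0.0], np.cumsum(softmax(logits_2))])\n",
        "  # Overlap of [cdf_1[i], cdf_1[i+1]) with [cdf_2[j], cdf_2[j+1]).\n",
        "  lo = np.maximum(cdf_1[:-1, None], cdf_2[None, :-1])\n",
        "  hi = np.minimum(cdf_1[1:, None], cdf_2[None, 1:])\n",
        "  return np.clip(hi - lo, 0.0, None)\n",
        "\n",
        "# Hypothetical example distributions.\n",
        "p = np.log(np.array([0.5, 0.5]))\n",
        "q = np.log(np.array([0.25, 0.75]))\n",
        "print(inverse_cdf_sample(p, 0.6), inverse_cdf_sample(q, 0.6))  # 1 1\n",
        "print(inverse_cdf_coupling_np(p, q))  # approximately [[0.25, 0.25], [0.0, 0.5]]"
      ],
      "execution_count": null,
      "outputs": []
    },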
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "mNB6vPWreMt3",
        "outputId": "4fb45657-d0e4-4dd0-b33c-e4ed0628960f"
      },
      "source": [
        "inverse_cdf_p_q = coupling_util.inverse_cdf_coupling(\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits)\n",
        "plt.imshow(inverse_cdf_p_q, vmin=0)\n",
        "plt.colorbar()"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b86abd310>"
            ]
          },
          "metadata": {},
          "execution_count": 17
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS0AAAD4CAYAAAC5Z7DGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATUElEQVR4nO3df6xfdX3H8eeLthRatDhQM1ocdUWXKpvgHaI4t1mVMn90ySApRmSGpFtmFY2LA/9A15gtLAZmApo0/AgDIiyVJXfaWado5s9KoUwslXhXNlrAQH+s/HDQ9t7X/viesrvr7fd7ar/nfr+fe16P5MTv93zP/Zx3CL74nM/5nM+RbSIiSnHcoAuIiDgaCa2IKEpCKyKKktCKiKIktCKiKHObaHTeogWe/8qX9r3d4352oO9tRpToeZ7jgF/QsbRxwR8u9J6947WOve/HL2yyvfJYztcvjYTW/Fe+lLOuv6zv7Z60ckff24wo0WZ/85jb2LN3nB9telWtY+f8+s9OPeYT9kkjoRURw8/ABBODLuOoJbQiWsqYg653eThMEloRLZaeVkQUw5jxAh/jS2hFtNgECa2IKISB8QJDq9bkUkkrJT0saUzSlU0XFREzYwLX2oZJz56WpDnADcA7gV3AvZJGbT/UdHER0RwDBwsc06rT0zoXGLO9w/YB4E5gVbNlRUTTjBmvuQ2TOmNai4Gdk77vAt409SBJa4A1AMe/ov+P8EREnxnGhyuPaunbA9O219sesT0yb9GJ/Wo2IhrSmRFfbxsmdXpajwGnT/q+pNoXEUUT4xzTM9cDUSe07gXOlLSUTlitBt7faFUR0bjOQPwsDC3bhyStBTYBc4CbbW9rvLKIaFRnntYsDC0A2xuBjQ3XEhEzbGI29rQiYnaa1T2tiJh9jBgvcMX1hFZEi+XyMCKKYcQBzxl0GUctoRXRUp3Jpbk8BGD+nHGWvnRv39t9qu8tRrRbiQPx5cVsRPSFLcZ9XK2tl17LV0maL+mu6vfNks6o9s+TdKukByVtl3RVr3MltCJabALV2rqZtHzVhcBy4BJJy6ccdjmwz/Yy4Drgmmr/xcB822cBbwT+7HCgHUlCK6KlOgPxc2ttPdRZvmoVcGv1eQOwQpLoDK0tlDQXOBE4ADzd7WQJrYiWOjwQX2cDTpW0ZdK2ZlJT0y1ftXjK6V48xvYhYD9wCp0Aew54AngU+JztrgPiuXsY0WLj9edp7bY90kAJ5wLjwGnAy4DvSPqG7SO+Tj49rYiWOjwjvs7WQ53lq148proUXATsobNizNdsH7T9JPA9oGs4JrQiWmzCx9Xaenhx+SpJx9NZvmp0yjGjwGXV54uAe2ybziXh2wEkLQTOA37a7WS5PIxoqc4D08febznS8lWS1gFbbI8CNwG3SRoD9tIJNujcdbxF0jZAwC22f9ztfAmtiJYy4mCfHuOZbvkq21dP+vw8nekNU//u2en2d5PQimgpm1oTR4dNQiuitXpPHB1GCa2IljLpaUVEYbIIYEQUwyiLAEZEOTqvECsvAsqrOCL6ZPa+rDUiZiFDndnuQyehFdFi6WlFRDFspacVEeXoDMTnbTwRUQxlculhB56Yz86/fU3f291128G+twmw7NKtjbQbMcw6A/EZ04qIgmRGfEQUIzPiI6I4ecN0RBTDhoMTCa2IKETn8jChFREFyYz4iChGqVMeevYNJZ0u6VuSHpK0TdIVM1FYRDRN/XqF2Iyq09M6BHzC9v2SXgLcJ+lfbT/UcG0R0bBZuUa87SeAJ6rPz0jaDiwGEloRBevcPZzlzx5KOgM4G9g8zW9rgDUA8088uQ+lRUSTSp1cWvtiVdJJwJeBj9l+eurvttfbHrE9Mu/4hf2sMSIaMlG9RqzXNkxq9bQkzaMTWHfYvrvZkiJiJpR697BnaEkScBOw3fa1zZcUETNl2O4M1lGnp3U+cCnwoK
QHqn2fsr2xubIiomm2ODQbQ8v2d2HILmojoi9m5eVhRMxOs3ZMKyJmr4RWRBSj1HlaCa2IFhu2OVh1NBJa2v8LTvjnH/W93X/4/P19bxNgHec00m7EMLPhUBYBjIiS5PIwIoqRMa2IKI4TWhFRkgzER0Qx7DLHtMq7dRARfSLGJ46rtfVsSVop6WFJY5KunOb3+ZLuqn7fXK3Nd/i335b0g2o59wclndDtXAmtiBazVWvrRtIc4AbgQmA5cImk5VMOuxzYZ3sZcB1wTfW3c4HbgT+3/TrgD4CD3c6X0IpoqcPPHtbZejgXGLO9w/YB4E5g1ZRjVgG3Vp83ACuqZa/eBfzY9r8D2N5je7zbyRJaEW3lzrhWnQ04VdKWSduaSS0tBnZO+r6r2sd0x9g+BOwHTgFeA1jSJkn3S/pkr7IzEB/RYkdx93C37ZEGSpgLvBX4XeAXwDcl3Wf7m0f6g/S0IlrK/RuIfww4fdL3JdW+aY+pxrEWAXvo9Mr+zfZu278ANkL35+oSWhEtdhSXh93cC5wpaamk44HVwOiUY0aBy6rPFwH32DawCThL0oIqzH6fHq8nzOVhRIv1Y0a87UOS1tIJoDnAzba3SVoHbLE9Suc9E7dJGgP20gk2bO+TdC2d4DOw0fZXu50voRXRUp1eVH8ml1bvjNg4Zd/Vkz4/D1x8hL+9nc60h1oSWhEtVuKM+IRWRIvVGK8aOgmtiJYyYiKLAEZESQrsaCW0IlqrjwPxMymhFdFmBXa1EloRLZaeVsPW/eYbG2l30+Nb+97mBae9oe9tRvSTgYmJhFZElMJAeloRUZLM04qIsiS0IqIcvZdSHkYJrYg2S08rIophcO4eRkRZygut2k9LSpojaaukrzRZUETMINfchsjRPOJ9BbC9qUIiYgBma2hJWgK8G7ix2XIiYsYcnlxaZxsidce0/h74JPCSIx1QvQdtDcAJLDj2yiKicSVOLu3Z05L0HuBJ2/d1O872etsjtkfmMb9vBUZEgyZUbxsidXpa5wPvk/RHwAnASyXdbvsDzZYWEU3TbOxp2b7K9hLbZ9B57c89CayIWaDuIPyQBVvmaUW01vANstdxVKFl+9vAtxupJCJm3pD1oupITyuizSYGXcDRS2hFtFUWAYyI0pR49zChFdFmBYZWea+XjYhWK6un1dAzB2f/zV/0vc03/PDBvrcJ8Ph5zzTSbrRTLg8johxm6B7RqSOhFdFm6WlFRElyeRgRZUloRURREloRUQo5l4cRUZrcPYyIkqSnFRFlKTC08hhPRFv5/8a1em29SFop6WFJY5KunOb3+ZLuqn7fLOmMKb+/StKzkv6y17kSWhFt1oflliXNAW4ALgSWA5dIWj7lsMuBfbaXAdcB10z5/VrgX+qUnNCKaDFN1Nt6OBcYs73D9gHgTmDVlGNWAbdWnzcAKyQJQNIfA48A2+rUnNCKiDpOlbRl0rZm0m+LgZ2Tvu+q9jHdMbYPAfuBUySdBPwV8Nd1C8lAfESb1R+I3217pIEKPgNcZ/vZquPVU0Iroq36N7n0MeD0Sd+XVPumO2aXpLnAImAP8CbgIkl/B5wMTEh63vb1RzpZQiuizfoTWvcCZ0paSiecVgPvn3LMKHAZ8APgIjrvTzXwe4cPkPQZ4NlugQUJrYh260No2T4kaS2wCZgD3Gx7m6R1wBbbo8BNwG2SxoC9dILtV5LQimgpUevOYC22NwIbp+y7etLn54GLe7TxmTrnSmhFtFUemI6I4iS0IqIoCa0yveL67/e9ze8ueXPf2wR46w9/0ki7ectPO+XyMCLKktCKiGK4f3cPZ1JCK6LN0tOKiJJkTCsiypLQiohi1FjgbxgltCJaSpR5eVhrEUBJJ0vaIOmnkrZLamYSUkTMqH6tET+T6va0Pg98zfZFko4HFjRYU0TMlCELpDp6hpakRcDbgD8FqNaAPtBsWRExIwoMrTqXh0uBp4BbJG2VdKOkhVMPkrTm8PrRB3mh74VGRJ
/18RViM6lOaM0FzgG+aPts4Dngl95rZnu97RHbI/OY3+cyI6IRfXiF2EyrE1q7gF22N1ffN9AJsYgoXJ9eITajeoaW7Z8DOyW9ttq1Anio0aoiYkaUeHlY9+7hR4A7qjuHO4APNVdSRMyIIbz0q6NWaNl+AGjinWcRMUizNbQiYvYpdUZ8QiuixTRRXmoltCLaajaPaUXE7JTLw4goS0IrDnv1lT9opN0H1r6lkXa3Pv6FRtq94LQ3NNJu9Ed6WhFRloRWRBQjb+OJiJJknlZElMflpVZCK6LF0tOKiHJkcmlElCYD8RFRlIRWRJTDZCA+IsqSgfiIKEtCKyJKkcmlEVEWu8hFAOu8QiwiZqs+vfdQ0kpJD0sak/RL70WVNF/SXdXvmyWdUe1/p6T7JD1Y/e/be50roRXRYv14hZikOcANwIXAcuASScunHHY5sM/2MuA64Jpq/27gvbbPAi4DbutVc0Iroq0MTLje1t25wJjtHbYPAHcCq6Ycswq4tfq8AVghSba32n682r8NOFFS11fUJ7Qi2qz+5eGpkrZM2tZMamUxsHPS913VPqY7xvYhYD9wypRj/gS43/YL3UrOQHxEix3F3cPdtht796mk19G5ZHxXr2MTWhEt1qe7h48Bp0/6vqTaN90xuyTNBRYBewAkLQH+Cfig7f/odbJcHka0Vd1Lw965di9wpqSlko4HVgOjU44ZpTPQDnARcI9tSzoZ+Cpwpe3v1Sk7Pa3CvOL67zfS7gU3nN1Iu1fvuK+Rdj/4vcv73uayS7f2vc1h1plceuw9LduHJK0FNgFzgJttb5O0DthiexS4CbhN0hiwl06wAawFlgFXS7q62vcu208e6XwJrYg269MqD7Y3Ahun7Lt60ufngYun+bvPAp89mnMltCJarB89rZmW0Ipoq6xcGhFlKfPZw4RWRJvl8jAiipGXtUZEcQrsadWaXCrp45K2SfqJpC9JOqHpwiJiBvRpaZqZ1DO0JC0GPgqM2H49ncljq7v/VUSUQBMTtbZhUvfycC6dJSMOAguAx3scHxHDzvRtculM6tnTsv0Y8DngUeAJYL/tr089TtKaw8tWHKTryhIRMQSEkettw6TO5eHL6CzgtRQ4DVgo6QNTj7O93vaI7ZF5dF3DKyKGhV1vGyJ1BuLfATxi+ynbB4G7gbc0W1ZEzIgCQ6vOmNajwHmSFgD/A6wAtjRaVUQ0r9AxrZ6hZXuzpA3A/cAhYCuwvunCIqJ5w3ZnsI5adw9tfxr4dMO1RMSMGr5LvzoyIz6irUxCKyIKU97VYUIros2GbQ5WHQmtiDZLaEVEMWwYL+/6MKEVHQ39F3fdq89ppN0l753X9zZf/v2T+94mwCNP/1rf25xY+53+NJSeVkQUJaEVEcUwkDXiI6IcBmdMKyJKYTIQHxGFyZhWRBQloRUR5cgD0xFREgOzdWmaiJil0tOKiHLkMZ6IKInBmacVEUXJjPiIKErGtCKiGHbuHkZEYdLTiohyGI+PD7qIo5bQimirLE0TEcUpcMrDcYMuICIGw4AnXGvrRdJKSQ9LGpN05TS/z5d0V/X7ZklnTPrtqmr/w5Iu6HWuhFZEW7laBLDO1oWkOcANwIXAcuASScunHHY5sM/2MuA64Jrqb5cDq4HXASuBL1TtHVFCK6LFPD5ea+vhXGDM9g7bB4A7gVVTjlkF3Fp93gCskKRq/522X7D9CDBWtXdEjYxpPcO+3d/whv+qceipwO4mamhISfWWVCscbb2jG/pfwWjtI4fhn+1vHGsDz7Bv0ze84dSah58gacuk7+ttr68+LwZ2TvptF/CmKX//4jG2D0naD5xS7f/hlL9d3K2QRkLL9svrHCdpi+2RJmpoQkn1llQrlFVvSbV2Y3vloGv4VeTyMCKO1WPA6ZO+L6n2TXuMpLnAImBPzb/9fxJaEXGs7gXOlLRU0vF0BtanXmyPApdVny8C7rHtav/q6u7iUuBM4EfdTjboeVrrex8yVEqqt6
Raoax6S6q1cdUY1VpgEzAHuNn2NknrgC22R4GbgNskjQF76QQb1XH/CDwEHAI+bLvryL9c4LNHEdFeuTyMiKIktCKiKAMLrV7T/oeFpNMlfUvSQ5K2Sbpi0DXVIWmOpK2SvjLoWrqRdLKkDZJ+Kmm7pDcPuqZuJH28+vfgJ5K+JOmEQdfUNgMJrZrT/ofFIeATtpcD5wEfHuJaJ7sC2D7oImr4PPA1278F/A5DXLOkxcBHgRHbr6cz6Lx6sFW1z6B6WnWm/Q8F20/Yvr/6/Ayd/1N1nbE7aJKWAO8Gbhx0Ld1IWgS8jc6dJWwfsP3fg62qp7nAidVcowXA4wOup3UGFVrTTfsf6iAAqJ5MPxvYPNhKevp74JPAsK87shR4CrilupS9UdLCQRd1JLYfAz4HPAo8Aey3/fXBVtU+GYivSdJJwJeBj9l+etD1HImk9wBP2r5v0LXUMBc4B/ii7bOB54BhHt98GZ0rgqXAacBCSR8YbFXtM6jQOuqp+4MkaR6dwLrD9t2DrqeH84H3SfpPOpfdb5d0+2BLOqJdwC7bh3uuG+iE2LB6B/CI7adsHwTuBt4y4JpaZ1ChVWfa/1Cols+4Cdhu+9pB19OL7atsL7F9Bp1/rvfYHsregO2fAzslvbbatYLOzOhh9ShwnqQF1b8XKxjiGwez1UAe4znStP9B1FLD+cClwIOSHqj2fcr2xgHWNJt8BLij+o/XDuBDA67niGxvlrQBuJ/OXeWt5JGeGZfHeCKiKBmIj4iiJLQioigJrYgoSkIrIoqS0IqIoiS0IqIoCa2IKMr/AnRFpmO6mf2tAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IMBhKc0HeUkZ"
      },
      "source": [
        "As we discuss in Section 2, if we want to estimate a difference of costs and each cost function is monotonic with respect to this ordering, this coupling minimizes the variance of that difference. However, this is only possible if we know the ordering in advance when building our causal mechanism. If we use a different order (here, a random permutation of the outcomes), this structure is destroyed."
      ]
    },
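    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The variance claim can be checked numerically with a minimal NumPy sketch. It uses hypothetical probabilities and a hypothetical cost that increases with the outcome index. It rebuilds the inverse-CDF coupling directly from probabilities instead of calling `coupling_util`. The shuffled variant mimics the idea behind `permuted_inverse_cdf_coupling` but is not necessarily identical to it. Because the cost is monotone in the original order, the ordered coupling should give the smallest variance of the cost difference."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def coupling_from_cdfs(p, q):\n",
        "  # Joint mass of (x, y) when both are drawn by inverse CDF from one shared u.\n",
        "  cdf_p = np.concatenate([[0.0], np.cumsum(p)])\n",
        "  cdf_q = np.concatenate([[0.0], np.cumsum(q)])\n",
        "  lo = np.maximum(cdf_p[:-1, None], cdf_q[None, :-1])\n",
        "  hi = np.minimum(cdf_p[1:, None], cdf_q[None, 1:])\n",
        "  return np.clip(hi - lo, 0.0, None)\n",
        "\n",
        "def variance_of_difference(joint, cost_x, cost_y):\n",
        "  # Var[cost_y(y) - cost_x(x)] for (x, y) distributed according to joint.\n",
        "  diff = cost_y[None, :] - cost_x[:, None]\n",
        "  mean = np.sum(joint * diff)\n",
        "  return np.sum(joint * (diff - mean) ** 2)\n",
        "\n",
        "# Hypothetical distributions and a cost that increases with the index.\n",
        "p = np.array([0.1, 0.2, 0.3, 0.4])\n",
        "q = np.array([0.4, 0.3, 0.2, 0.1])\n",
        "cost = np.array([0.0, 1.0, 3.0, 7.0])\n",
        "\n",
        "ordered = coupling_from_cdfs(p, q)\n",
        "# Same construction, but with the outcomes visited in a shuffled order.\n",
        "perm = np.random.default_rng(1).permutation(4)\n",
        "shuffled = np.zeros((4, 4))\n",
        "shuffled[np.ix_(perm, perm)] = coupling_from_cdfs(p[perm], q[perm])\n",
        "independent = np.outer(p, q)\n",
        "\n",
        "for name, joint in [('ordered', ordered), ('shuffled', shuffled), ('independent', independent)]:\n",
        "  print(name, variance_of_difference(joint, cost, cost))"
      ],
      "execution_count": null,
      "outputs": []
    },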
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "5OZFg4uJe34X",
        "outputId": "ead2f385-baba-428f-9ef8-f08cefff5916"
      },
      "source": [
        "perm_inverse_cdf_p_q = coupling_util.permuted_inverse_cdf_coupling(\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    permutation_seed=1)\n",
        "plt.imshow(perm_inverse_cdf_p_q, vmin=0)\n",
        "plt.colorbar()"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b869ea810>"
            ]
          },
          "metadata": {},
          "execution_count": 18
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS0AAAD4CAYAAAC5Z7DGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUOElEQVR4nO3dfbBdVX3G8e/DDYQ3C0haWpLYZATtRK1Fr4DYqjUKQS3pTGEaHC0inbQzgvhe6IxI0c5I64COpU5TCEOBKWikMxlNjQraF6Ux4aXQEKnXQEkAkZAUBITk3vv0j7Pj3F7uPWdfc/Y9Z+U8n5k1s8/ea6/9S4Afa6299t6yTUREKQ7odQARETORpBURRUnSioiiJGlFRFGStCKiKHOaaHToRYd5zryjut7u3Aef7XqbESV6jmfY7ee1L22c9ruH+YmdY7Xq3nHP8+ttL9uX63VLI0lrzryj+NVLz+96uy87946utxlRog2+dZ/beGLnGN9f/5JadYd+7Yfz9vmCXdJI0oqI/mdgnPFehzFjSVoRA8qYPa43POwnSVoRAyw9rYgohjFjBT7Gl6QVMcDGSdKKiEIYGCswadVaXCppmaT7JY1IuqjpoCJidozjWqWfdOxpSRoCrgLeBmwHNkpaa/u+poOLiOYY2FPgnFadntaJwIjtrbZ3AzcBy5sNKyKaZsxYzdJP6sxpzQe2Tfi9HThpciVJK4GVAENHH9mV4CKiQYax/spHtXTtgWnbq2wP2x4eetFh3Wo2IhrSWhFfr3TSad5b0hsl3SlpVNKZk46dI+mHVTmn07Xq9LQeBhZO+L2g2hcRRRNj7NMz161W6s17PwS8F/jopHNfDHwSGKaVR++ozt013fXq9LQ2AsdLWizpIGAFsLb+Hyki+lFrIl61Sgcd571tP2j7Hl7YcTsN+KbtnVWi+ibQ9m0SHXtatkclnQ+sB4aA1bY3dzovIvpba51W7Z7WPEmbJvxeZXtVtV1r3nsaU507v90JtRaX2l4HrKsZREQUYrxzL2qvHbaHm4ylrry5NGJA7e1p1Skd7Mu894zPTdKKGFBGjHFArdLBvsx7rwdOlXSUpKOAU6t900rSihhg41at0o7tUWDvvPcW4Eu2N0u6TNIZAJJeJ2k7cBbwd5I2V+fuBD5FK/FtBC6r9k0rD0xHDCgjdnuoO21NMe9t+5IJ2xtpDf2mOnc1sLrutZK0IgZUa3FpeYOtRpLW3AefbeQjFE+fVfcu6swc/uUNjbRbkscuOKWRdo/5wvcaaTe6oxuLS2dbeloRA8oWY05PKyIKMp6eVkSUojURX14KKC/iiOiKTMRHRHHG6j/G0zeStCIG1N4V8aVJ0ooYYOO5exgRpWg9MJ2kFRGFMGJPlx7jmU1JWhEDyiaLSyOiJMri0ogoh0lPKyIKk4n4iCiG6fyCv36UpBUxoFqfECsvBZQXcUR0SXc+1jrbkrQiBpTJiviIKEx6WhFRDFvpaUVEOVoT8XmMJyKKkXfEN66pr+aM37qwc6UZOmDptq632aR8NWfwtCbiM6cVEQXJiviIKEZWxEdEcfJhi4gohg17xpO0IqIQreFhklZEFCQr4iOiGKUueejYN5S0UNK3Jd0nabOkC2cjsIhoWmt4WKf0kzrRjAIfsb0EOBl4v6QlzYYVEbNhvHpPfKfSiaRlku6XNCLpoimOz5V0c3V8g6RF1f4DJV0n6V5JWyRd3OlaHZOW7Udt31lt/xTYAszv+KeIiL7Wuns4VKu0I2kIuAo4HVgCnD1Fx+Y8YJft44Argcur/WcBc22/Cngt8Cd7E9p0ZtTvqxo7AXjB8zSSVkraJGnTHp6fSbMR0QN7F5fWKR2cCIzY3mp7N3ATsHxSneXAddX2GmCpJNGaWjtM0hzgEGA38FS7i9VOWpIOB74CfND2Cxq1vcr2sO3hA5lbt9mI6KEZDA/n7e2UVGXlhG
bmAxMftt3OC0djP69jexR4EjiaVgJ7BngUeAj4rO2d7WKudfdQ0oG0EtaNtm+pc05E9LcZ3j3cYXu4gTBOBMaAY4GjgH+T9C3bW6c7oc7dQwHXAFtsX9GtSCOi97p09/BhYOKrUhZU+6asUw0FjwCeAN4FfN32Hts/Ab4LtE2OdYaHbwDeA7xF0t1VeXuN8yKij9li1AfUKh1sBI6XtFjSQcAKYO2kOmuBc6rtM4HbbJvWkPAtAJIOo7VC4QftLtZxeGj736HAZbMR0VE3FpfaHpV0PrAeGAJW294s6TJgk+21tEZr10saAXbSSmzQuut4raTNtPLMtbbvaXe9rIiPGFDdXBFvex2wbtK+SyZsP0drecPk856ean87SVoRA6zEx3iStCIGVF4CGBHFqfOITr9J0qKZj1DsfN/ru94mwFOLG2mWRZ+4vZmGo2/ZMJqXAEZESTI8jIhiZE4rIorjJK2IKEkm4iOiGHbmtCKiKGIsdw8joiSZ04qIYpT6NZ4krYhB5da8VmmStCIGWO4eRkQxnIn4iChNhocRUZTcPYyIYthJWhFRmCx5iIiiZE4rIophxHjuHkZESQrsaCVpRQysTMRHRHEK7GolaUUMsPS04udevLqZr9vs+utmvvIT8NgFpzTS7jFf+F4j7e4rA+PjSVoRUQoD6WlFREmyTisiypKkFRHlUCbiI6IwBfa0ylvDHxHdYfC4apVOJC2TdL+kEUkXTXF8rqSbq+MbJC2acOw3Jd0uabOkeyUd3O5aSVoRA001S5sWpCHgKuB0YAlwtqQlk6qdB+yyfRxwJXB5de4c4AbgT22/AngzsKfd9WonLUlDku6S9NW650REn3PN0t6JwIjtrbZ3AzcByyfVWQ5cV22vAZZKEnAqcI/t/wSw/YTtsXYXm0lP60JgywzqR0S/q5+05knaNKGsnNDKfGDbhN/bq31MVcf2KPAkcDTwMsCS1ku6U9LHO4VcayJe0gLgHcBfAh+uc05E9LmZLS7dYXu4gSjmAL8NvA54FrhV0h22b53uhLo9rc8BHwfGp6sgaeXeLLyH52cQc0T0il2vdPAwsHDC7wXVvinrVPNYRwBP0OqV/avtHbafBdYBr2l3sY5JS9I7gZ/YvqNdPdurbA/bHj6QuZ2ajYh+MK56pb2NwPGSFks6CFgBrJ1UZy1wTrV9JnCbbQPrgVdJOrRKZm8C7mt3sTrDwzcAZ0h6O3Aw8EuSbrD97hrnRkQfUxfWadkelXQ+rQQ0BKy2vVnSZcAm22uBa4DrJY0AO2klNmzvknQFrcRnYJ3tr7W7XsekZfti4GIASW8GPpqEFbEfqHdnsF5T9jpaQ7uJ+y6ZsP0ccNY0595Aa9lDLVkRHzGwtP+/5cH2d4DvNBJJRMy+Ah/jSU8rYpBNux6gfyVpRQyqvAQwIkrTjbuHsy1JK2KQFZi08paHiChKelrAg5/q/hduFn2ima/xvPRjzbT7o4a+8tNUvE3o16/mNCnDw4goh6nziE7fSdKKGGTpaUVESTI8jIiyJGlFRFGStCKiFHKGhxFRmtw9jIiSpKcVEWVJ0oqIYmROKyKKk6QVESVRgS8BzFseIqIo6WlFDLIMDyOiGJmIj4jiJGlFRFGStCKiFKLMu4dJWhGDKnNaEVGcJK2IKEqSVssBL5/DIX9/TNfb/dmbHut6m9Dcl3NK0tRXc3a+r5mv/Lx4df6ZdUOGhxFRliStiCiGy7x7mGcPIwaZa5YOJC2TdL+kEUkXTXF8rqSbq+MbJC2adPwlkp6W9NFO10rSihhge98T36m0bUMaAq4CTgeWAGdLWjKp2nnALtvHAVcCl086fgXwz3ViTtKKGGTd6WmdCIzY3mp7N3ATsHxSneXAddX2GmCpJAFI+n3gAWBznZCTtCIGVd2E1Upa8yRtmlBWTmhpPrBtwu/t1T6mqmN7FHgSOFrS4cCfAX9RN+xMxEcMKDGjJQ87bA83EMalwJW2n646Xh3VSlqSjgSuBl
5JK+++z3YWykQUrkvrtB4GFk74vaDaN1Wd7ZLmAEcATwAnAWdK+ivgSGBc0nO2/2a6i9XtaX0e+LrtMyUdBBxa87yI6GfdSVobgeMlLaaVnFYA75pUZy1wDnA7cCZwm20Dv7O3gqRLgafbJSyokbQkHQG8EXgvQDXRtrvenyUi+loXkpbtUUnnA+uBIWC17c2SLgM22V4LXANcL2kE2Ekrsf1C6vS0FgOPA9dKejVwB3Ch7WcmVqom5lYCHHLM4b9oPBExW7r4lgfb64B1k/ZdMmH7OeCsDm1cWudade4ezgFeA3zR9gnAM8ALFo/ZXmV72PbwQUceUufaEdFrXVpcOpvqJK3twHbbG6rfa2glsYgonMbrlX7SMWnZ/jGwTdLLq11LgfsajSoiZkU3VsTPtrp3Dy8AbqzuHG4Fzm0upIiYFX049KujVtKyfTfQxMKyiOil/TVpRcT+Z4Yr4vtGklbEANN4eVkrSStiUO3Pc1oRsX/K8DAiypKk1TJ+/2hjX86Jsjy1uJl2j7x1YedKM3TA0m2dK+1n0tOKiLIkaUVEMQr9Gk+SVsSAyjqtiCiPy8taSVoRAyw9rYgoRxaXRkRpMhEfEUVJ0oqIcphMxEdEWTIRHxFlSdKKiFJkcWlElMXOSwAjojDl5awkrYhBluFhRJTDQIaHEVGU8nJWklbEIMvwMCKKkruHEVGOvOWheesfubuRdk879rcaaTdg0Sdu73UItT191kmNtHv4lzc00u6+ai0u7U7WkrQM+DwwBFxt+zOTjs8F/gF4LfAE8Ie2H5T0NuAzwEHAbuBjtm9rd60DuhJxRJRpvGZpQ9IQcBVwOrAEOFvSkknVzgN22T4OuBK4vNq/A/g9268CzgGu7xRyklbEAJNdq3RwIjBie6vt3cBNwPJJdZYD11Xba4ClkmT7LtuPVPs3A4dUvbJpJWlFDCrPoMA8SZsmlJUTWpoPTPxo5PZqH1PVsT0KPAkcPanOHwB32n6+XdhFzWlFRDfN6NnDHbaHm4pE0itoDRlP7VQ3Pa2IQWbXK+09DEz85PeCat+UdSTNAY6gNSGPpAXAPwF/ZPtHnS6WpBUxqKqPtdYpHWwEjpe0WNJBwApg7aQ6a2lNtAOcCdxm25KOBL4GXGT7u3XCTtKKGGRd6GlVc1TnA+uBLcCXbG+WdJmkM6pq1wBHSxoBPgxcVO0/HzgOuETS3VX5lXbXqzWnJelDwB/TmpK7FzjX9nN1zo2IPtalxaW21wHrJu27ZML2c8BZU5z3aeDTM7lWx56WpPnAB4Bh26+ktXhsxUwuEhH9SePjtUo/qXv3cA6t9RN7gEOBRzrUj4h+ZzouHO1HHXtath8GPgs8BDwKPGn7G5PrSVq5dw3HHtous4iIPiDqLSzt1qM+3VJneHgUrdWsi4FjgcMkvXtyPdurbA/bHj6QtgtaI6JfdGfJw6yqc/fwrcADth+3vQe4BTil2bAiYlYUmLTqzGk9BJws6VDgZ8BSYFOjUUVE8wqd0+qYtGxvkLQGuBMYBe4CVjUdWEQ0r9/uDNZR6+6h7U8Cn2w4loiYVf039KsjD0xHDCqTpBURhSlvdJikFTHI+m0NVh1JWhGDLEkrIophw1h548NGkpbmzmVo0Uu73u5px3a9ycYc8i/HNNLuz970WCPtRnNfzfnva1/b9Tafv7RLXzlKTysiipKkFRHFMJAvTEdEOQzOnFZElMJkIj4iCpM5rYgoSpJWRJQjD0xHREkM7K+vpomI/VR6WhFRjjzGExElMTjrtCKiKFkRHxFFyZxWRBTDzt3DiChMeloRUQ7jsbFeBzFjSVoRgyqvpomI4hS45OGAXgcQEb1hwOOuVTqRtEzS/ZJGJF00xfG5km6ujm+QtGjCsYur/fdLOq3TtZK0IgaVq5cA1iltSBoCrgJOB5YAZ0taMqnaecAu28cBVwKXV+cuAVYArwCWAX9btTetJK2IAeaxsVqlgxOBEdtbbe8GbgKWT6qzHLiu2l
4DLJWkav9Ntp+3/QAwUrU3rUbmtJ56/sc71t9/+f/UqDoP2NFEDA2pH+8bmw2khv3377b3Zhbre9c0EcOv72sDP2XX+m95zbya1Q+WtGnC71W2V1Xb84FtE45tB06adP7P69gelfQkcHS1/z8mnTu/XSCNJC3bv1ynnqRNtoebiKEJJcVbUqxQVrwlxdqO7WW9juEXkeFhROyrh4GFE34vqPZNWUfSHOAI4Ima5/4/SVoRsa82AsdLWizpIFoT62sn1VkLnFNtnwncZtvV/hXV3cXFwPHA99tdrNfrtFZ1rtJXSoq3pFihrHhLirVx1RzV+cB6YAhYbXuzpMuATbbXAtcA10saAXbSSmxU9b4E3AeMAu+33XbmXy7w2aOIGFwZHkZEUZK0IqIoPUtanZb99wtJCyV9W9J9kjZLurDXMdUhaUjSXZK+2utY2pF0pKQ1kn4gaYuk1/c6pnYkfaj69+C/JP2jpIN7HdOg6UnSqrnsv1+MAh+xvQQ4GXh/H8c60YXAll4HUcPnga/b/g3g1fRxzJLmAx8Ahm2/ktak84reRjV4etXTqrPsvy/YftT2ndX2T2n9R9V2xW6vSVoAvAO4utextCPpCFrPDlwDYHu37f/tbVQdzQEOqdYaHQo80uN4Bk6vktZUy/77OhEAVE+mnwBs6G0kHX0O+DjQ7+8dWQw8DlxbDWWvlnRYr4Oaju2Hgc8CDwGPAk/a/kZvoxo8mYivSdLhwFeAD9p+qtfxTEfSO4Gf2L6j17HUMAd4DfBF2ycAzwD9PL95FK0RwWLgWOAwSe/ubVSDp1dJa8ZL93tJ0oG0EtaNtm/pdTwdvAE4Q9KDtIbdb5F0Q29DmtZ2YLvtvT3XNbSSWL96K/CA7cdt7wFuAU7pcUwDp1dJq86y/75QvT7jGmCL7St6HU8nti+2vcD2Ilp/r7fZ7svegO0fA9skvbzatZTWyuh+9RBwsqRDq38vltLHNw72Vz15jGe6Zf+9iKWGNwDvAe6VdHe1789tr+thTPuTC4Abq/95bQXO7XE807K9QdIa4E5ad5XvIo/0zLo8xhMRRclEfEQUJUkrIoqSpBURRUnSioiiJGlFRFGStCKiKElaEVGU/wOPtg3xStdObwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TlgABLbQfDSG"
      },
      "source": [
        "## Independent couplings"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nVlR1vSyfFeq"
      },
      "source": [
        "Another coupling we compare against is the independent coupling, under which $x \\sim p$ and $y \\sim q$ are drawn independently, so the joint is simply the product $p(x)\\,q(y)$. From a causal perspective, this corresponds to a situation where the outcome observed under one intervention tells you nothing at all about what the outcome would have been under a different, counterfactual intervention."
      ]
    },
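    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Concretely, the independent coupling is the outer product of the two marginals. The sketch below is our own stand-in, and it assumes `coupling_util.independent_coupling` computes the same thing from softmax-normalized logits. The helper name and the toy logits are hypothetical.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def independent_coupling_sketch(logits_1, logits_2):\n",
        "  # Hypothetical stand-in: the joint of two independent categoricals\n",
        "  # is the outer product of their probability vectors.\n",
        "  p = jax.nn.softmax(logits_1)\n",
        "  q = jax.nn.softmax(logits_2)\n",
        "  return jnp.outer(p, q)\n",
        "\n",
        "toy_p = jnp.array([2.0, 1.0, 0.0])\n",
        "toy_q = jnp.array([0.0, 1.0, 2.0])\n",
        "joint = independent_coupling_sketch(toy_p, toy_q)\n",
        "\n",
        "# Row sums recover p(x) and column sums recover q(y).\n",
        "print(joint.sum(axis=1), jax.nn.softmax(toy_p))\n",
        "print(joint.sum(axis=0), jax.nn.softmax(toy_q))\n",
        "# Probability that both interventions yield the same outcome.\n",
        "print(jnp.trace(joint))\n",
        "```\n",
        "\n",
        "Because the joint has rank one, every row is proportional to $q$, so conditioning on an observed $x$ leaves the counterfactual distribution equal to $q$."
      ]
    },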
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "0FjGrh_XfEiW",
        "outputId": "29eba140-28d5-48cc-ef1a-9851a2a34977"
      },
      "source": [
        "independent_pq = coupling_util.independent_coupling(\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits)\n",
        "plt.imshow(independent_pq, vmin=0)\n",
        "plt.colorbar()"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f5b86d7f650>"
            ]
          },
          "metadata": {},
          "execution_count": 19
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAATQAAAD4CAYAAABi3BrkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUkElEQVR4nO3dfaxdVZ3G8e/Te3lrpRXbcYItM21Cx0llYtAGUSYmWtE6OpZkINbxhXGI/UNRfJkYMBETon+QGN8iw6QDOIgomIrxxqlWDZqJk1gpLxEKEq+A0IqRNwsy9uXe+8wfZ1dObu+9Z5eefc89az+fZId99ll79dfm3h9r7bXXWrJNREQJFg06gIiIfklCi4hiJKFFRDGS0CKiGEloEVGM0UYqXbzExy17Uf8rnup/lQBqoF41NHjcRKwANDTa3Vy8/a9STY34N1Dt/gN/4OChZ3UsdbzpdUv8xJOTtcre/osDO2xvPJY/bz40ktCOW/YiVv/rR/te7+j/9b3KTr1/6v9P3Oj+vldZ1dvML93IwWYyz6IDDcV7oP/xLjpY75f7qOud6H+sO3/xH8dcxxNPTvLzHX9Vq+zIqb9accx/4DxoJKFFxMJnYKqpbs+AJKFFtJQxh9xMq3RQktAiWiwttIgogjGThU19TEKLaLGpJoZgBygJLaKlDEwWltBqvVgraaOk+yWNS7q06aAiYn5M4VrHsOjZQpM0AlwFnAvsAW6TNGb73qaDi4jmGDhU2DO0Oi20s4Bx2w/YPgjcBGxqNqyIaJoxkzWPYVHnGdpK4JGuz3uAV00vJGkLsAVgdOkpfQkuIhpkmByeXFVL3yan295qe73t9aOLl/Sr2ohoSGemQL1jWNRpoe0FTuv6vKq6FhFDTUxyTPPbF5w6Ce02YK2kNXQS2WbgnxuNKiIa1xkUaFlCsz0h6WJgBzACXGd7d+ORRUSjOu+htSyhAdjeDmxvOJaImGdTbWuhRUSZWttCi4jyGDFZ2Cr8SWgRLZYuZ0QUwYiDHhl0GH2VhBbRUp0Xa9Pl7MmLYGJxE3MqmmoeN1FvU3NKmvo3aOoHu6n3zMv6RTxqffoxKG1QoOU/FRHtZYtJL6p19NJriTFJJ0i6ufp+p6TV1fVzJd0u6e7qv6/vuueV1fVxSV+S1DP7JqFFtNgUqnXMpWuJsTcD64B3SFo3rdhFwFO2Twc+D1xZXX8c+EfbfwdcCNzQdc/VwPuAtdXRc1/QJLSIluoMCozWOnqos8TYJuD66nwbsEGSbN9p+7fV9d3ASVVr7lRgqe2f2TbwVeC8XoEkoUW01OFBgToHsELSrq5jS1dVMy0xtnLaH/fnMrYngH3A8mll/gm4w/aBqvyeHnUeIaOcES02Wf89tMdtr28qDkkvo9MNfeOx1JOEFtFSfZwpUGeJscNl9kgaBZYBTwBIWgV8G3iP7V93lV/Vo84jpMsZ0WJTXlTr6OHPS4xJOp7OEmNj08qM0XnoD3A+cKttS3oh8N/Apbb/93Bh248CT0s6uxrdfA/wnV6BJKFFtFRncvqiWsec9XSeiR1eYuw+4Ju2d0u6QtLbqmLXAssljQMfBQ6/2nExcDpwuaS7quPF1XfvB64BxoFfA9/r9XdKlzOipYw41KepTzMtMWb78q7z/cAFM9z3aeDTs9S5CzjjaOJIQotoKZtaL80OkyS0iNbq/dLssElCi2gpkxZaRBQkCzxGRBGMssBjRJShs41dWSmgrL9NRByFdm40HBEFMtSZBTBUktAiWiwttIgogq200CKiDJ1Bgez6FBFFUF6srcOLYGJJU7seNaGJ5whNPZvIblIdTewmNTy/3O69X0jvOshGwxFRkMwUiIgiZKZARBQlO6dHRBFsODSVhBYRBeh0OZPQIqIQmSkQEUUo8bWNnu1NSadJ+rGkeyXtlnTJfAQWEU1Tv7axWzDqtNAmgI/ZvkPSycDtkn
5o+96GY4uIhrVuT4Fqw89Hq/NnJN0HrASS0CKGWGeUs8VzOSWtBs4Eds7w3RZgC8DIKaf0IbSIaFKJL9bW7hxLegHwLeDDtp+e/r3trbbX214/smRJP2OMiIZMVVvZ9TqGRa0WmqTj6CSzG23f0mxIETEfShzl7JnQJAm4FrjP9ueaDyki5sswjWDWUaeFdg7wbuBuSXdV1z5he3tzYUVE02wx0baEZvunNLdYVkQMUOu6nBFRplY+Q4uIciWhRUQRSnwPLQktosWG6R2zOppJaCNm6gWTfa92gmGaptHUD0o2X+loYnSuiY1XoJFY+/DPasNEFniMiFKkyxkRRcgztIgoipPQIqIUGRSIiCLY5T1DK2uIIyKOgpicWlTr6FmTtFHS/ZLGJV06w/cnSLq5+n5ntbYikpZXS/z/UdKXp93zk6rOu6rjxb3iSAstosX68QxN0ghwFXAusAe4TdLYtGX6LwKesn26pM3AlcDbgf3AJ4EzqmO6d9reVTeWtNAiWurwXM46Rw9nAeO2H7B9ELgJ2DStzCbg+up8G7BBkmw/Wy2Asb8ff6cktIi2cuc5Wp0DWCFpV9expaumlcAjXZ/3VNeYqYztCWAfsLxGlF+pupufrNZmnFO6nBEtdhSjnI/bXt9kLDN4p+291W5z36KzLuNX57ohLbSIlnL/BgX2Aqd1fV5VXZuxjKRRYBnwxJzx2Xur/z4DfJ1O13ZOSWgRLXYUXc653AaslbRG0vHAZmBsWpkx4MLq/HzgVnv2miWNSlpRnR8HvBW4p1cg6XJGtFg/RjltT0i6GNgBjADX2d4t6Qpgl+0xOvuS3CBpHHiSTtIDQNJDwFLgeEnnAW8EfgPsqJLZCPAj4D97xZKEFtFSndZXf16srfYY2T7t2uVd5/uBC2a5d/Us1b7yaONIQotosdJmCiShRbRYjedjQyUJLaKljJjKAo8RUYrCGmhJaBGt1cdBgYUiCS2izQproiWhRbRYWmg1aMQct/RA3+s9xAl9rxOym1Sz9Q7TblJNPSDv/25S7kOoBqamktAiogQG0kKLiFLkPbSIKEcSWkSUQRkUiIiCpIUWEUUwOKOcEVGOshJa7bdZJI1IulPSd5sMKCLmkWseQ+JoXs+7BLivqUAiYgDamNAkrQLeAlzTbDgRMW8Ov1hb5xgSdZ+hfQH4OHDybAWqffq2AIyuWHbskUVE40p7sbZnC03SW4Hf2759rnK2t9peb3v9yNIlfQswIho0pXrHkKjTQjsHeJukfwBOBJZK+prtdzUbWkQ0TW1rodm+zPaqameWzXT200syixh2dQcEhijp5T20iNYargf+dRxVQrP9E+AnjUQSEfNviFpfdaSFFtFm/V97cqCS0CLaKgs8RkRJShvlTEKLaLPCElpZ2yZHRKs10kIbGZli+bJn+17vE32vsaOJ3aSGaycpyG5S0FysDbQb1J9Y0+WMiDKYoZrWVEcSWkSbpYUWEaVIlzMiypGEFhHFSEKLiBLI6XJGREkyyhkRpUgLLSLKUVhCy9SniLbyc8/Reh29SNoo6X5J45IuneH7EyTdXH2/U9Lq6vpyST+W9EdJX552zysl3V3d8yWp9/SIJLSINuvDEtySRoCrgDcD64B3SFo3rdhFwFO2Twc+D1xZXd8PfBL4txmqvhp4H7C2Ojb2+uskoUW0mKbqHT2cBYzbfsD2QeAmYNO0MpuA66vzbcAGSbL9rO2f0klsz8UlnQostf0z2wa+CpzXK5AktIioY4WkXV3Hlq7vVgKPdH3eU11jpjK2J4B9wPI5/ryVVT1z1XmEDApEtFn9QYHHba9vMJK+SAstoq36NyiwFzit6/Oq6tqMZSSNAsuYe0WwvVU9c9V5hCS0iDbrz76ctwFrJa2RdDyd/XvHppUZAy6szs+ns7/vrDXbfhR4WtLZ1ejme4Dv9AokXc6INuvDe2i2JyRdDOwARoDrbO+WdAWwy/YYcC1wg6Rx4Ek6SQ8ASQ8BS4HjJZ
0HvNH2vcD7gf8CTgK+Vx1zSkKLaClRawSzFtvbge3Trl3edb4fuGCWe1fPcn0XcMbRxJGEFtFWmZweEUVJQouIYiSh9XbCyARrlj7ZRNWNaGI3qSZ2koLsJtVsvU39dvc/Vvfp/YR0OSOiHEloEVEE92+Uc6FIQotos7TQIqIUeYYWEeVIQouIItSbpzlUktAiWkqU1+Ws9TaLpBdK2ibpl5Luk/TqpgOLiOb1a0+BhaJuC+2LwPdtn18tD7K4wZgiYr4MUbKqo2dCk7QMeC3wLwDVmuEHmw0rIuZFYQmtTpdzDfAY8BVJd0q6RtKS6YUkbTm83viBp/YfWUtELCx93MZuoaiT0EaBVwBX2z4TeBY4Yt8921ttr7e9/oRTTuxzmBHRiP6sWLtg1Eloe4A9tndWn7fRSXARMeT6tI3dgtEzodn+HfCIpJdWlzYA9zYaVUTMi9K6nHVHOT8I3FiNcD4AvLe5kCJiXgxZd7KOWgnN9l3Agt+TLyKOUhsTWkSUp8SZAkloES2mqbIyWhJaRFu19RlaRJQpXc6IKEcSWm+LFx3k5Uv3NFH10GhiJynIblLPaWLXp6Z2qOp/1nCfQk0LLSLKkYQWEUXIrk8RUYq8hxYRZXFZGS0JLaLF0kKLiDLkxdqIKEkGBSKiGEloEVEGk0GBiChHBgUiohxJaBFRgrxYGxHlsItb4LHONnYRUao+7cspaaOk+yWNSzpi315JJ0i6ufp+p6TVXd9dVl2/X9Kbuq4/JOluSXdJ2lXnr5MWWkSL9aPLKWkEuAo4l84+vrdJGrPdvd3lRcBTtk+XtBm4Eni7pHXAZuBlwEuAH0n6G9uT1X2vs/143VjSQotoKwNTrnfM7Sxg3PYDtg8CNwGbppXZBFxfnW8DNkhSdf0m2wdsPwiMV/U9L0loEW1Wv8u5QtKurmNLVy0rgUe6Pu+prjFTGdsTwD5geY97DfxA0u3T/rxZpcsZ0WJH0eV83PZ8783797b3Snox8ENJv7T9P3PdkBZaRItpyrWOHvYCp3V9XlVdm7GMpFFgGZ2V6me91/bh//4e+DY1uqJJaBFtVbe72bsVdxuwVtIaScfTecg/Nq3MGHBhdX4+cKttV9c3V6Oga4C1wM8lLZF0MoCkJcAbgXt6BdJIl3PJyAFeveRXTVTdetl8pUlNbZLS/3rdh6ZI58XaYx/mtD0h6WJgBzACXGd7t6QrgF22x4BrgRskjQNP0kl6VOW+CdwLTAAfsD0p6S+Bb3fGDRgFvm77+71iyTO0iDbr02obtrcD26ddu7zrfD9wwSz3fgb4zLRrDwAvP9o4ktAiWqwfLbSFJAktoq2yYm1ElKO8uZxJaBFtli5nRBQhGw1HRFEKa6HVeptF0kck7ZZ0j6RvSDqx6cAiYh70afmghaJnQpO0EvgQsN72GXRenNvcdGAR0TxNTdU6hkXdLucocJKkQ8Bi4LfNhRQR88L07cXahaJnC62aIPpZ4GHgUWCf7R9MLydpy+GlRfY9MTn964hYYISR6x3Dok6X8xQ6i7CtobOi5BJJ75pezvZW2+ttr1+2fJjm70W0mF3vGBJ1BgXeADxo+zHbh4BbgNc0G1ZEzIvCElqdZ2gPA2dLWgz8CdgA1NqwICIWsAKfofVMaLZ3StoG3EFneY87ga1NBxYRzRumEcw6ao1y2v4U8KmGY4mIeTVc3ck6MlMgoq1MElpEFKSsHmcSWkSbDdM7ZnUkoUW0WRJaRBTBhsmy+pyNJLSTBa9tZD2O7CTVlOwm1aQGdpPq1waUaaFFRDGS0CKiCAayp0BElMHgPEOLiBKYDApEREHyDC0iipGEFhFlyOT0iCiFgTYuHxQRhUoLLSLKkKlPEVEKg/MeWkQUIzMFIqIYeYYWEUWwM8oZEQVJCy0iymA8OTnoIPoqCS2irbJ8UEQUpbDXNvq1kG9EDBkDnnKtoxdJGyXdL2lc0qUzfH+CpJur73dKWt313WXV9fslvalunTNJQotoK1
cLPNY55iBpBLgKeDOwDniHpHXTil0EPGX7dODzwJXVveuAzcDLgI3Av0saqVnnEZLQIlrMk5O1jh7OAsZtP2D7IHATsGlamU3A9dX5NmCDJFXXb7J9wPaDwHhVX506j9DIM7Tbf3Hg8ZFTx39To+gK4PEmYjg643ULLpB4axmmWGG44l0Isf71sVbwDE/t+JG3rahZ/ERJu7o+b7W9tTpfCTzS9d0e4FXT7v9zGdsTkvYBy6vrP5t278rqvFedR2gkodn+izrlJO2yvb6JGJowTPEOU6wwXPEOU6xzsb1x0DH0W7qcEXGs9gKndX1eVV2bsYykUWAZne1gZ7u3Tp1HSEKLiGN1G7BW0hpJx9N5yD82rcwYcGF1fj5wq21X1zdXo6BrgLXAz2vWeYRBv4e2tXeRBWWY4h2mWGG44h2mWBtXPRO7GNgBjADX2d4t6Qpgl+0x4FrgBknjwJN0EhRVuW8C9wITwAdsTwLMVGevWOTC5nJFRHulyxkRxUhCi4hiDCyhPZ9pDYMg6TRJP5Z0r6Tdki4ZdEx1VG9b3ynpu4OOZS6SXihpm6RfSrpP0qsHHdNcJH2k+jm4R9I3JJ046JjiOQNJaM93WsOATAAfs70OOBv4wAKOtdslwH2DDqKGLwLft/23wMtZwDFLWgl8CFhv+ww6D6s3Dzaq6DaoFtrzmtYwCLYftX1Hdf4MnV+4lXPfNViSVgFvAa4ZdCxzkbQMeC2dETBsH7T9h8FG1dMocFL1LtVi4LcDjie6DCqhzTRVYkEnCYBqhYAzgZ2DjaSnLwAfBxb62jBrgMeAr1Td42skLRl0ULOxvRf4LPAw8Ciwz/YPBhtVdMugQE2SXgB8C/iw7acHHc9sJL0V+L3t2wcdSw2jwCuAq22fCTwLLOTnqafQ6UmsAV4CLJH0rsFGFd0GldCe17SGQZF0HJ1kdqPtWwYdTw/nAG+T9BCdrvzrJX1tsCHNag+wx/bhFu82OgluoXoD8KDtx2wfAm4BXjPgmKLLoBLa85rWMAjVEifXAvfZ/tyg4+nF9mW2V9leTeff9VbbC7IVYft3wCOSXlpd2kDnjfGF6mHgbEmLq5+LDSzgQYw2GsjUp9mmSgwilhrOAd4N3C3pruraJ2xvH2BMJfkgcGP1P7YHgPcOOJ5Z2d4paRtwB53R7zvJNKgFJVOfIqIYGRSIiGIkoUVEMZLQIqIYSWgRUYwktIgoRhJaRBQjCS0iivH/PwksRDHtE8sAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0aYYUWF2StmY"
      },
      "source": [
        "## Gadget 1 and Gadget 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lLpQIZ3XT3wj"
      },
      "source": [
        "We now show how to use our learnable \"gadgets\" to define couplings and to draw samples from a counterfactual.\n",
        "\n",
        "As discussed in Section 5.1, Gadget 1 deviates from a standard SCM in that the exogenous noise is not shared exactly between the \"observed\" and \"counterfactual\" samples. In particular, sampling from a counterfactual distribution requires transposing the matrix of Gumbels, so one of the two interventions must be designated as the non-transposed original and the other as the transposed counterfactual.\n",
        "\n",
        "Gadget 2, on the other hand, satisfies the standard requirements of an SCM and uses the same exogenous noise across all possible interventions. This means it can be used in any situation where the Gumbel-max SCM can."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hX7dqgEIXAd8"
      },
      "source": [
        "Both gadgets are implemented as `flax` modules, which separate the definition of the model class $\\{f_\\theta\\}_{\\theta \\in \\Theta}$ from the specific value of the parameters $\\theta$. We can instantiate each model class by specifying all of the necessary hyperparameters. For instance:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XHSIMRWoR_-E"
      },
      "source": [
        "# S_dim is the number of outcomes for our distribution of interest.\n",
        "gadget_1_def = gadget_1.GadgetOneMLPPredictor(\n",
        "    S_dim=10, hidden_features=[1024, 1024],\n",
        "    relaxation_temperature=1.0)\n",
        "\n",
        "# Gadget 2 also requires Z_dim, the size of the space of the latent auxiliary variable.\n",
        "gadget_2_def = gadget_2.GadgetTwoMLPPredictor(\n",
        "    S_dim=10, Z_dim=100, hidden_features=[1024, 1024],\n",
        "    relaxation_temperature=1.0, learn_prior=False)"
      ],
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o9_LiFFPYKUE"
      },
      "source": [
        "To use them to draw samples, we must pick a particular value for $\\theta$. We can start by randomly initializing each:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3X29TDQDYERV",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "528a18a7-26b3-4225-9835-c44fb764955e"
      },
      "source": [
        "init_key = jax.random.PRNGKey(1001)\n",
        "gadget_1_theta = gadget_1_def.init(init_key, jnp.zeros([gadget_1_def.S_dim]))\n",
        "\n",
        "init_key = jax.random.PRNGKey(1002)\n",
        "gadget_2_theta = gadget_2_def.init(init_key, jnp.zeros([gadget_2_def.S_dim]))\n",
        "\n",
        "# Summarize the shape of each parameter tree:\n",
        "print(\"Gadget 1:\")\n",
        "print(jax.tree_util.tree_map(lambda x: f\"dtype={x.dtype} shape={x.shape} values={x.reshape([-1])[:4]}...\", gadget_1_theta))\n",
        "print(\"Gadget 2:\")\n",
        "print(jax.tree_util.tree_map(lambda x: f\"dtype={x.dtype} shape={x.shape} values={x.reshape([-1])[:4]}...\", gadget_2_theta))"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Gadget 1:\n",
            "FrozenDict({\n",
            "    params: {\n",
            "        hidden_layers_0: {\n",
            "            bias: 'dtype=float32 shape=(1024,) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(10, 1024) values=[-0.15586407  0.52728695  0.4646642  -0.02859419]...',\n",
            "        },\n",
            "        hidden_layers_1: {\n",
            "            bias: 'dtype=float32 shape=(1024,) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(1024, 1024) values=[-0.0640266   0.04275422  0.02152636  0.04715558]...',\n",
            "        },\n",
            "        output_layer: {\n",
            "            bias: 'dtype=float32 shape=(10, 10) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(1024, 10, 10) values=[-0.0507848  -0.00499428 -0.03920214  0.0249158 ]...',\n",
            "        },\n",
            "    },\n",
            "})\n",
            "Gadget 2:\n",
            "FrozenDict({\n",
            "    params: {\n",
            "        hidden_layers_0: {\n",
            "            bias: 'dtype=float32 shape=(1024,) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(10, 1024) values=[-0.33588833  0.00931741  0.5640106  -0.33654755]...',\n",
            "        },\n",
            "        hidden_layers_1: {\n",
            "            bias: 'dtype=float32 shape=(1024,) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(1024, 1024) values=[-0.00269697  0.03676963 -0.01132549 -0.03541469]...',\n",
            "        },\n",
            "        output_layer: {\n",
            "            bias: 'dtype=float32 shape=(100, 10) values=[0. 0. 0. 0.]...',\n",
            "            kernel: 'dtype=float32 shape=(1024, 100, 10) values=[ 0.01256981 -0.00283486 -0.01371078 -0.02073329]...',\n",
            "        },\n",
            "    },\n",
            "})\n"
          ]
        }
      ]
    },
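    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on these shapes, we can count the parameters in each tree. The two hidden layers are identical; only the output layer differs. Gadget 1's output layer produces an `S_dim` × `S_dim` array, while Gadget 2's produces a `Z_dim` × `S_dim` array, so with `Z_dim=100` it is ten times larger. The shapes below are transcribed by hand from the output above rather than read from the live parameters:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "# Parameter shapes copied by hand from the printed trees above.\n",
        "gadget_1_shapes = [(10, 1024), (1024,), (1024, 1024), (1024,), (1024, 10, 10), (10, 10)]\n",
        "gadget_2_shapes = [(10, 1024), (1024,), (1024, 1024), (1024,), (1024, 100, 10), (100, 10)]\n",
        "\n",
        "def num_params(shapes):\n",
        "  # Each array contributes the product of its dimensions.\n",
        "  return sum(math.prod(shape) for shape in shapes)\n",
        "\n",
        "print(num_params(gadget_1_shapes))  # 1163364\n",
        "print(num_params(gadget_2_shapes))  # 2085864\n",
        "```\n",
        "\n",
        "On the live parameter trees, `sum(x.size for x in jax.tree_util.tree_leaves(gadget_1_theta))` should give the same total."
      ]
    },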
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z7EkURiaZSdz"
      },
      "source": [
        "We can also bind a particular value of $\\theta$ to each model definition to obtain a concrete mechanism $f_\\theta$. (This is only recommended for interactive use cases, such as this notebook. If you want to learn $\\theta$, it's better to keep the two separate. See the [flax documentation](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) for more details on using flax.)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pzIYrUC_YFYb"
      },
      "source": [
        "gadget_1_at_init = gadget_1_def.bind(gadget_1_theta)\n",
        "gadget_2_at_init = gadget_2_def.bind(gadget_2_theta)"
      ],
      "execution_count": 22,
      "outputs": []
    },
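    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The difference between the two styles can be seen on a toy `flax.linen` module (a hypothetical stand-in, not one of the gadgets): `apply` takes the parameters as an explicit argument, which is what lets `jax.grad` differentiate with respect to them, whereas `bind` fixes them inside the module object.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import flax.linen as nn\n",
        "\n",
        "class TinyMLP(nn.Module):\n",
        "  # Hypothetical toy module, unrelated to the gadgets.\n",
        "  features: int\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    return nn.Dense(self.features)(x)\n",
        "\n",
        "model_def = TinyMLP(features=3)\n",
        "theta = model_def.init(jax.random.PRNGKey(0), jnp.zeros([4]))\n",
        "x = jnp.ones([4])\n",
        "\n",
        "# Functional style: theta is an explicit argument, so we can differentiate\n",
        "# with respect to it (this is what a training loop needs).\n",
        "y_apply = model_def.apply(theta, x)\n",
        "grads = jax.grad(lambda t: model_def.apply(t, x).sum())(theta)\n",
        "\n",
        "# Bound style: handy interactively, but theta is now fixed inside the object.\n",
        "model_bound = model_def.bind(theta)\n",
        "y_bound = model_bound(x)\n",
        "print(jnp.allclose(y_apply, y_bound))\n",
        "```\n",
        "\n",
        "Both calls compute the same outputs; the functional form is the one to prefer when $\\theta$ is being learned."
      ]
    },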
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wnKpasOTklYg"
      },
      "source": [
        "### Sampling from the gadgets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9QMn6ZmKZ-0X"
      },
      "source": [
        "Given a bound gadget, we can draw samples just as with Gumbel-max. Each gadget defines a `sample` method, which draws a sample according to the gadget's structural causal model. As with Gumbel-max, passing the same PRNG key when sampling two different logit vectors produces coupled interventions. However, as noted above, Gadget 1 requires passing `transpose=True` when sampling the second logit vector."
      ]
    },
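    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the Gumbel-max pattern that `sample` mirrors can be sketched in plain JAX: the PRNG key determines the exogenous Gumbel noise, so reusing a key across two logit vectors couples the outcomes, while fresh keys give independent draws. This is our own minimal illustration with hypothetical toy logits, not the gadgets' code.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def gumbel_max_sample(logits, rng):\n",
        "  # The exogenous noise is one standard Gumbel per outcome, fixed by rng.\n",
        "  gumbels = jax.random.gumbel(rng, logits.shape)\n",
        "  return jnp.argmax(logits + gumbels)\n",
        "\n",
        "toy_p = jnp.array([2.0, 1.0, 0.0, -1.0])\n",
        "toy_q = jnp.array([1.0, 2.0, 0.0, -1.0])\n",
        "sample_many = jax.vmap(gumbel_max_sample, in_axes=(None, 0))\n",
        "\n",
        "keys = jax.random.split(jax.random.PRNGKey(0), 1000)\n",
        "other_keys = jax.random.split(jax.random.PRNGKey(1), 1000)\n",
        "x = sample_many(toy_p, keys)\n",
        "y_coupled = sample_many(toy_q, keys)           # same keys: coupled\n",
        "y_independent = sample_many(toy_q, other_keys)  # fresh keys: independent\n",
        "print('P(x == y), shared keys:', float(jnp.mean(x == y_coupled)))\n",
        "print('P(x == y), fresh keys: ', float(jnp.mean(x == y_independent)))\n",
        "```\n",
        "\n",
        "With shared keys the two interventions agree noticeably more often than with fresh keys, which is the behavior the loop above relies on."
      ]
    },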
    {
      "cell_type": "code",
      "metadata": {
        "id": "8tmJCr0jZ95b",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e1c869a7-e41a-4307-e30a-cb84c34e1d5a"
      },
      "source": [
        "keys = jax.random.split(jax.random.PRNGKey(42), 20)\n",
        "g1_p_samples = []\n",
        "g1_q_samples = []\n",
        "g2_p_samples = []\n",
        "g2_q_samples = []\n",
        "for prng_key in keys:\n",
        "  # Gadget 1\n",
        "  g1_p_samples.append(int(gadget_1_at_init.sample(p_logits, prng_key)))\n",
        "  g1_q_samples.append(int(gadget_1_at_init.sample(q_logits, prng_key, transpose=True)))\n",
        "  # Gadget 2\n",
        "  g2_p_samples.append(int(gadget_2_at_init.sample(p_logits, prng_key)))\n",
        "  g2_q_samples.append(int(gadget_2_at_init.sample(q_logits, prng_key)))\n",
        "\n",
        "print(\"g1_p_samples\", g1_p_samples)\n",
        "print(\"g1_q_samples\", g1_q_samples)\n",
        "print()\n",
        "print(\"g2_p_samples\", g2_p_samples)\n",
        "print(\"g2_q_samples\", g2_q_samples)"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "g1_p_samples [2, 1, 2, 6, 1, 1, 3, 4, 9, 2, 7, 4, 8, 5, 0, 6, 2, 9, 5, 4]\n",
            "g1_q_samples [8, 5, 8, 6, 5, 4, 9, 1, 4, 0, 6, 1, 2, 1, 7, 2, 2, 2, 6, 5]\n",
            "\n",
            "g2_p_samples [4, 0, 8, 7, 6, 0, 7, 3, 6, 6, 2, 1, 7, 7, 2, 4, 6, 3, 2, 8]\n",
            "g2_q_samples [4, 0, 1, 7, 1, 0, 1, 3, 1, 6, 2, 1, 2, 2, 2, 4, 6, 3, 2, 1]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 369
        },
        "id": "grMxLIavfsN6",
        "outputId": "a96db984-575a-48f5-90af-495319c85b53"
      },
      "source": [
        "g1_init_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_1_at_init.sample, second_kwargs={\"transpose\": True}),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "g2_init_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_2_at_init.sample),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "\n",
        "_, axs = plt.subplots(ncols=2, figsize=(12,6))\n",
        "axs[0].imshow(g1_init_pq, vmin=0)\n",
        "axs[1].imshow(g2_init_pq, vmin=0)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b864abe10>"
            ]
          },
          "metadata": {},
          "execution_count": 24
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAAFPCAYAAABJfdYtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATLklEQVR4nO3df6zdd33f8df73mvHjvPDoZDR/KDJBgNlSBvUY4FMTApM0JaVfyYtqDCtahVNWssPVUJ0fwypf05VRf/oKmWETlqz8kfKpBYh6NaCpk0o1AlokDiZUID8rEhWQkLi+Nf97I9rSykj9rn2PefjvO/jIUWKr4/P+/O1r99+nuPje2qMEQAA6GRt9gEAAGCniVwAANoRuQAAtCNyAQBoR+QCANDOxjLudO/GpWP/3oPLuOuzOnnp+spnnjFmjd6cNDdJTfrCHGPiQ7M9Pzo1Ze7m3okXPelzrCZ95ZcXX3wmx088X1OGT7L/4CXjimsOrHzu0SMrHwk082Kez/Fx7Cfu7KVE7v69B3PzG391GXd9Vk+/dfVhfcbJA3P+TFw/Ou9LwK2dnDP31L45c5Pk6q/+YMrco9ddPmVukqwfm1O56y/OeUDxtW/8hylzZ7rimgP5F3e9Z+Vzj/zspCUCtHHP+POX/T4vVwAAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7C0VuVb23qh6qqm9X1SeWfSgAzp+dDbBA5FbVepLfS/JzSW5K8oGqumnZBwNg++xsgC2LPJP7tiTfHmM8PMY4nuSzSd6/3GMBcJ7sbIAsFrnXJnn0Jd9+7PTHALj42NkA2cF/eFZVt1fV4ao6fPzkCzt1twAswUt39tEfHJt9HIAdt0jkPp7k+pd8+7rTH/sbxhh3jDEOjTEO7d24dKfOB8D2bHtn77/qkpUdDmBVFoncv0zyhqq6sar2JrktyZ8s91gAnCc7GyDJxrluMMY4WVW/luRLSdaTfGaMcf/STwbAttnZAFvOGblJMsb4QpIvLPksAOwAOxvAO54BANCQyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0M5C73i2XcevXM8jP3/VMu76rF543cmVzzzjxv+6OWXuC1cv5ZdwIc+9bs5jpKseOjVlbpIcf82BKXP3PHtiytwk2Xju2JS5m/v3TJmbMWfsTEcfTB78R6uf+yv/5zurH3ranX/3xmmzgdXwTC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgnY1l3OnmnuToT59axl2f1U/dt77ymWcc/ak5jxc2jo4pc5Pk+ME5s188OO+x2am9e6bMHWtz5ibJJc/unTP36WNT5u5GVWupSy5Z+dzP3PSGlc8841Pf/R9T5n70hndMmQu7kWdyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgnXNGblVdX1VfrqoHqur+qvrIKg4GwPbZ2QBbNha4zckkvzHGuK+qLk9yb1X9tzHGA0s+GwDbZ2cDZIFncscYT44x7jv9/88lOZLk2mUfDIDts7MBtmzrNblVdUOStyS5ZxmHAWDn2NnAbrbIyxWSJFV1WZI/TvLRMcazP+H7b09ye5KsX3Vwxw4IwPZtZ2fvqwMrPh3A8i30TG5V7cnWsrxrjPG5n3SbMcYdY4xDY4xD65ddtpNnBGAbtruz99a+1R4QYAUW+eoKleTOJEfGGL+z/C
MBcL7sbIAtizyTe0uSDyW5taq+cfq/n1/yuQA4P3Y2QBZ4Te4Y438mqRWcBYALZGcDbPGOZwAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0c853PDu/e93M+quPLeWuz+b/vn05l7OIg4f3Tpn7wrVjytwkec29c2bX5rxrnmX9xLxrPn7ZnMfCe5+d9Bh8N75XWFVqz7z9OcPHXv/OKXPvfOQrU+Ymya+87h9Pmw0zeCYXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0M7GUu715FpOPbVvKXd9Nv/61j9f+cwz/tPlN0+Ze/kXL5syN0mef+2cx0hrJ6aMTZK86qFj84ZPUifnzD121Z4pczc3asrcqdbWUgcOrH7uiy+ufuZkv/p3bp02+/OPf3XK3Pdd+7NT5oJncgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoJ2FI7eq1qvq61X1+WUeCIALZ2cDu912nsn9SJIjyzoIADvKzgZ2tYUit6quS/ILST693OMAcKHsbIDFn8n9VJKPJ9l8uRtU1e1VdbiqDp/60fM7cjgAzsu2dvbxzaOrOxnAipwzcqvqfUm+P8a492y3G2PcMcY4NMY4tH7ZgR07IACLO5+dvXdt/4pOB7A6izyTe0uSX6yq7yb5bJJbq+oPl3oqAM6XnQ2QBSJ3jPGbY4zrxhg3JLktyV+MMT649JMBsG12NsAWXycXAIB2NrZz4zHGV5J8ZSknAWBH2dnAbuaZXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADa2dY7ni3qtVc8k4//0z9dxl2f1b8//J6VzzxjvLCUn8pzOvmGKWOTJGsnxpS5r/7mnLlJcvQ1e6bMPXZFTZmbJJc/dnLK3JN716fMTc37uZ5mbS3jwP6Vj621ic+zrE36/Fp7cc7cJP/shrdPmfulJ742ZW6SvOeafzBtNvN5JhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQzsYy7vTpY5flM995xzLu+qzG8XnNvv/VL0yZe3Tz0ilzk+Tgt5by6XNOz11XU+YmycH3PDll7tp/vHrK3CR59nVzfp03XpwyNmMXPvQf62vZvHL1u2Rtfd5P9rwtsvu892feNm32f370y1Pmfuj6W6bM5W/ahescAIDuRC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaWShyq+pgVd1dVQ9W1ZGqevuyDwbA+bGzAZKNBW/3u0m+OMb451W1N8mlSzwTABfGzgZ2vXNGblVdmeSdSf5Vkowxjic5vtxjAXA+7GyALYu8XOHGJE8l+YOq+npVfbqqDvz4jarq9qo6XFWHTz77wo4fFICFbHtnnzj5/OpPCbBki0TuRpK3Jvn9McZbkjyf5BM/fqMxxh1jjENjjEMbV/ibMYBJtr2z92z8fw0M8Iq3SOQ+luSxMcY9p799d7YWKAAXHzsbIAtE7hjjr5I8WlVvPP2hdyV5YKmnAuC82NkAWxb96gq/nuSu0/9K9+Ekv7y8IwFwgexsYNdbKHLHGN9IcmjJZwFgB9
jZAN7xDACAhkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2Fn1b3225Zt8z+eQb/3QZd31WH/7rD6x85hl/68rnpsx9/MT6lLlJcuxVS/n0OacTl48pc5PkxFd/esrcuu3ZKXOT5Oo790+Ze2r/nMfgdWrK2LnWklP7Vv/7edS+lc88Y9bmrLWaNHmizc1po//ljf9kytx/9/DXpsxNkt/622+dNvti45lcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsby7jTU1nLc6f2L+Ouz2pj76mVzzxjrcaUue9+/UNT5ibJF4/9vSlz65k9U+YmyeaJOXMP/PfL5wxO8sQ758w9eWBzytwTX5sydqqxXjl+5ep/X+35Ua185hm1OWdnr02amyRVk36+N+f92TzLb73+H06b/UsPfm/K3LvedN2UuWfjmVwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhnocitqo9V1f1V9a2q+qOq2rfsgwFwfuxsgAUit6quTfLhJIfGGG9Osp7ktmUfDIDts7MBtiz6coWNJPuraiPJpUmeWN6RALhAdjaw650zcscYjyf57SSPJHkyyQ/HGH/247erqtur6nBVHX7ur0/u/EkBOKfz2dknjj+/6mMCLN0iL1e4Ksn7k9yY5JokB6rqgz9+uzHGHWOMQ2OMQ5e/amPnTwrAOZ3Pzt6z98CqjwmwdIu8XOHdSb4zxnhqjHEiyeeSvGO5xwLgPNnZAFksch9JcnNVXVpVleRdSY4s91gAnCc7GyCLvSb3niR3J7kvyTdP/5g7lnwuAM6DnQ2wZaEXz44xPpnkk0s+CwA7wM4G8I5nAAA0JHIBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALSz0Duebder1k7ltst/sIy7Pqvv3fTVlc8844cn90+Z+8Lm3ilzk2ScmPMY6Zo3fX/K3CR54qGrp8w9tX/e49FT+zenzH3t/6opc5/+0ZSxU22uV45dOeFzrJbyR9CCs+eMnXjFWRtjytzat2/K3CTJ5pxrXjt1asrcJPkvb75hytw33zvnz4r//Usv/32eyQUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0U2OMnb/TqqeSfO88f/irkzy9g8d5JXDN/e22601eudf8M2OM18w+xCrZ2dvmmncH1/zK8LI7eymReyGq6vAY49Dsc6ySa+5vt11vsjuveTfajb/Ornl3cM2vfF6uAABAOyIXAIB2LsbIvWP2ASZwzf3ttutNduc170a78dfZNe8OrvkV7qJ7TS4AAFyoi/GZXAAAuCAiFwCAdi6ayK2q91bVQ1X17ar6xOzzLFtVXV9VX66qB6rq/qr6yOwzrUpVrVfV16vq87PPsgpVdbCq7q6qB6vqSFW9ffaZlq2qPnb68/pbVfVHVbVv9pnYebtpb9vZdnZnXXf2RRG5VbWe5PeS/FySm5J8oKpumnuqpTuZ5DfGGDcluTnJv9kF13zGR5IcmX2IFfrdJF8cY7wpyd9P82uvqmuTfDjJoTHGm5OsJ7lt7qnYabtwb9vZu4ed3WRnXxSRm+RtSb49xnh4jHE8yWeTvH/ymZ
ZqjPHkGOO+0///XLZ+E10791TLV1XXJfmFJJ+efZZVqKork7wzyZ1JMsY4PsZ4Zu6pVmIjyf6q2khyaZInJp+Hnber9radbWc313JnXyyRe22SR1/y7ceyC5bHGVV1Q5K3JLln7klW4lNJPp5kc/ZBVuTGJE8l+YPTf9336ao6MPtQyzTGeDzJbyd5JMmTSX44xvizuadiCXbt3razW7OzG+3siyVyd62quizJHyf56Bjj2dnnWaaqel+S748x7p19lhXaSPLWJL8/xnhLkueTdH/t4lXZekbvxiTXJDlQVR+ceyrYGXZ2e3Z2o519sUTu40muf8m3rzv9sdaqak+2luVdY4zPzT7PCtyS5Ber6rvZ+qvNW6vqD+ceaekeS/LYGOPMMz53Z2uBdvbuJN8ZYzw1xjiR5HNJ3jH5TOy8Xbe37Ww7u6m2O/tiidy/TPKGqrqxqvZm6wXPfzL5TEtVVZWt1/wcGWP8zuzzrMIY4zfHGNeNMW7I1q/xX4wxWjxafDljjL9K8mhVvfH0h96V5IGJR1qFR5LcXFWXnv48f1ea/8ONXWpX7W07285urO3O3ph9gCQZY5ysql9L8qVs/au+z4wx7p98rGW7JcmHknyzqr5x+mP/dozxhYlnYjl+Pcldp0Pg4SS/PPk8SzXGuKeq7k5yX7b+RfrX0+ytItmVe9vO3j3s7CY729v6AgDQzsXycgUAANgxIhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQzv8DtXnZt069iJ8AAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 864x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mHnralYtcaTW"
      },
      "source": [
        "(Note: even at initialization, Gadget 2 behaves similarly to Gumbel-max, in that it tends to produce the same outcome under $p$ and $q$. Gadget 1, on the other hand, often draws distinct outcomes at initialization, because its exogenous noise is transposed.)"
      ]
    },
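    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Producing the same outcome under $p$ and $q$ corresponds to mass on the diagonal of the joint, and the agreement probability $P(x = y)$ is the trace of the joint. Assuming `coupling_util.joint_from_samples` builds its estimate by counting sampled $(x, y)$ pairs (we have not checked its internals), the idea can be sketched with a hypothetical helper on a handful of hand-written pairs:\n",
        "\n",
        "```python\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def joint_from_pairs(x, y, num_outcomes):\n",
        "  # Scatter-add accumulates repeated (x, y) pairs, so each cell counts\n",
        "  # how often that pair was drawn; dividing by the sample count normalizes.\n",
        "  counts = jnp.zeros((num_outcomes, num_outcomes)).at[x, y].add(1.0)\n",
        "  return counts / x.shape[0]\n",
        "\n",
        "x = jnp.array([0, 0, 1, 2])\n",
        "y = jnp.array([0, 1, 1, 2])\n",
        "joint = joint_from_pairs(x, y, num_outcomes=3)\n",
        "print(joint)\n",
        "print('P(x == y) =', float(jnp.trace(joint)))  # 0.75\n",
        "```\n",
        "\n",
        "Applied to the estimates above, `jnp.trace(g2_init_pq)` would be expected to exceed `jnp.trace(g1_init_pq)`, in line with the note."
      ]
    },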
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4T06poU2ktUg"
      },
      "source": [
        "### Drawing counterfactual samples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7KFb_eprku7Y"
      },
      "source": [
        "Each gadget also provides a method `gadget.counterfactual_sample(p_logits, q_logits, p_observed, rng)`. This method plays the same role as counterfactual sampling for Gumbel-max SCMs: it draws a sample from the counterfactual distribution given by `q_logits`, conditioned on a particular outcome observed under `p_logits`."
      ]
    },
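    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, here is a minimal sketch of the usual counterfactual procedure for a Gumbel-max SCM, which `counterfactual_sample` is analogous to: sample the exogenous Gumbels from their posterior given the observed $x$ (using top-down truncated-Gumbel sampling), then replay them under `q_logits`. This is our own plain-JAX illustration with toy logits and a hypothetical function name, not the implementation behind either gadget.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.scipy.special import logsumexp\n",
        "\n",
        "def gumbel_max_counterfactual(p_logits, q_logits, x_obs, rng):\n",
        "  k_top, k_rest = jax.random.split(rng)\n",
        "  # Max of the perturbed p logits: Gumbel(logsumexp(p_logits)), independent\n",
        "  # of which index attains it, so it can be drawn first.\n",
        "  top = logsumexp(p_logits) + jax.random.gumbel(k_top)\n",
        "  # Other perturbed logits: Gumbel(p_logits[j]) truncated to lie below top.\n",
        "  untruncated = p_logits + jax.random.gumbel(k_rest, p_logits.shape)\n",
        "  truncated = -jnp.logaddexp(-top, -untruncated)\n",
        "  is_obs = jnp.arange(p_logits.shape[0]) == x_obs\n",
        "  perturbed = jnp.where(is_obs, top, truncated)\n",
        "  # Recover the exogenous Gumbels and replay them under q.\n",
        "  gumbels = perturbed - p_logits\n",
        "  return jnp.argmax(q_logits + gumbels)\n",
        "\n",
        "toy_p = jnp.array([2.0, 1.0, 0.0, -1.0])\n",
        "toy_q = jnp.array([1.0, 2.0, 0.0, -1.0])\n",
        "keys = jax.random.split(jax.random.PRNGKey(0), 500)\n",
        "cf = jax.vmap(gumbel_max_counterfactual, in_axes=(None, None, None, 0))\n",
        "\n",
        "# Sanity check: if q equals p, the counterfactual must reproduce the observation.\n",
        "same = cf(toy_p, toy_p, 0, keys)\n",
        "print(bool(jnp.all(same == 0)))\n",
        "# Empirical counterfactual distribution over y given x = 0.\n",
        "ys = cf(toy_p, toy_q, 0, keys)\n",
        "print(jnp.bincount(ys, length=4) / ys.shape[0])\n",
        "```"
      ]
    },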
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 221
        },
        "id": "aSKW3xTHlShl",
        "outputId": "945c0243-fc32-42a3-d608-0bd1a0858d6a"
      },
      "source": [
        "x_obs = 7  # for example\n",
        "rng = jax.random.PRNGKey(1234)\n",
        "\n",
        "def sample_ctf_gadgets(key):\n",
        "  y_from_gadget_1 = gadget_1_at_init.counterfactual_sample(p_logits, q_logits, x_obs, key)\n",
        "  y_from_gadget_2 = gadget_2_at_init.counterfactual_sample(p_logits, q_logits, x_obs, key)\n",
        "  return (\n",
        "      jnp.zeros([10]).at[y_from_gadget_1].set(1.),\n",
        "      jnp.zeros([10]).at[y_from_gadget_2].set(1.),\n",
        "  )\n",
        "\n",
        "# Average the one-hot samples to get each gadget's empirical counterfactual distribution.\n",
        "from_g1, from_g2 = jax.vmap(sample_ctf_gadgets)(jax.random.split(rng, 1000))\n",
        "from_g1 = jnp.mean(from_g1, axis=0)\n",
        "from_g2 = jnp.mean(from_g2, axis=0)\n",
        "_, axs = plt.subplots(nrows=2)\n",
        "axs[0].imshow(from_g1[None, :], vmin=0)  # Gadget 1\n",
        "axs[1].imshow(from_g2[None, :], vmin=0)  # Gadget 2"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b8645df10>"
            ]
          },
          "metadata": {},
          "execution_count": 25
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAC7CAYAAABmfSVyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOSUlEQVR4nO3dbYydZZ3H8e/PlgdbVqhitLbEYnTBRuOio4uSGEMhihrYRE0g0YCR1GRF0DVR0Y3Z+GbRGB9eGBNSdI0SJKlkrYb4lEKyG7NdCtQHQKSigdaqpSCC8lT73xfnrjOZndJp7zNzj+f6fpLJ3A/XnOvfq3N+58x9zrmuVBWSpMn3jKELkCQtDgNfkhph4EtSIwx8SWqEgS9JjTDwJakRvQI/ybOT/CDJPd33VYdo95ckO7qvLX36lCQdnfR5H36STwMPVtVVST4KrKqqj8zR7tGqOqFHnZKknvoG/t3AG6pqT5LVwM1Vddoc7Qx8SRpY32v4z6uqPd32b4HnHaLd8Um2J/mfJP/Us09J0lFYfrgGSX4IPH+OUx+fuVNVleRQfy68sKp2J3kRsDXJT6vql3P0tRHYCLByRV51+ouPPew/YCHd99TKQfs/6MRlfx66BO5/bM6XZxZdHlk2dAnUCQeGLgGA2p+hS+DY4/cPXQIABx44ZugSqCXyFpjHHtj1QFU9d65zi3JJZ9bP/Afwnara/HTtpl5xfP3v90456trG4fLfvHrQ/g9680k/GboE/mXHO4YuAYBj/vtZQ5fA4699dOgSAHhy3/FDl8C6l/xu6BIAePwrq4cugaeeOfwDMMDtmz50a1VNzXWu72PSFuDibvti4FuzGyRZleS4bvtk4Czgzp79SpKOUN/Avwo4N8k9wDndPkmmkmzq2rwU2J7kx8BNwFVVZeBL0iI77DX8p1NV+4ANcxzfDlzabf8IeHmffiRJ/S2RlxkkSQvNwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqRFjCfwkb0pyd5Kd3bz4s88fl+T67vy2JOvG0a8kaf56B36SZcAXgfOA9cBFSdbPavYe4KGqejHwOeBTffuVJB2ZcTzDfw2ws6ruraongW8AF8xqcwHw1W57M7AhydKYWk6SGjGOwF8D3D9jf1d3bM42VbUfeBh4zuwbSrKxWyhl+959fxlDaZKkg5bUi7ZVdXVVTVXV1HOfM/xCF5I0ScYR+LuBmSuVrO2OzdkmyXLgRGDfGPqWJM3TOAL/FuAlSU5NcixwIaOFUWaauVDK24Gt1WepLUnSEes1Hz6MrsknuQz4HrAM+HJV3ZHkk8D2qtoCXAN8LclO4EFGDwqSpEXUO/ABqupG4MZZxz4xY/txYGksiipJjVpSL9pKkhaOgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMWawGUS5LsTbKj+7p0HP1Kkuav9ydtZyyAci6jqZFvSbKlqu6c1fT6qrqsb3+SpKOzWAugSJIGNo65dOZaAOUf52j3tiSvB34BfLCq7p/dIMlGYGO3++iy1Tvv7lnbycADR//jO3t2Px5fHM/N9ByLO8ZTxdLQbyw+N75CloBeY3HfGAtZAnreR5aMFx7qxFgmT5uHbwPXVdUTSd7LaLnDs2c3qqqrgavH1WmS7VU1Na7b+1vmWExzLKY5FtNaGItFWQClqvZV1RPd7ibgVWPoV5J0BBZlAZQkq2fsng/cNYZ+JUlHYLEWQLk8yfnAfkYLoFzSt995GtvloQngWExzLKY5FtMmfiziSoOS1AY/aStJjTDwJakRExv4h5vuoRVJTklyU5I7k9yR5IqhaxpSkmVJbk/ynaFrGVqSk5JsTvLzJHclee3QNQ0lyQe7+8fPklyX5Piha1oIExn4M6Z7OA9YD1yUZP2wVQ1mP/ChqloPnAm8r+GxALgC3yV20BeA71bV6cAraH
RckqwBLgemqupljN58cuGwVS2MiQx8nO7hr6pqT1Xd1m0/wuhOvWbYqoaRZC3wFkafBWlakhOB1wPXAFTVk1X1h2GrGtRy4JlJlgMrgN8MXM+CmNTAn2u6hyZDbqYk64AzgG3DVjKYzwMfBg4MXcgScCqwF/hKd4lrU5KVQxc1hKraDXyG0UwRe4CHq+r7w1a1MCY18DVLkhOAbwIfqKo/Dl3PYkvyVuD3VXXr0LUsEcuBVwJfqqozgD8BTb7WlWQVoysApwIvAFYmeeewVS2MSQ38w0730JIkxzAK+2ur6oah6xnIWcD5SX7N6BLf2Um+PmxJg9oF7Kqqg3/tbWb0ANCic4BfVdXeqnoKuAF43cA1LYhJDfzDTvfQiiRhdJ32rqr67ND1DKWqrqyqtVW1jtHvw9aqmshncfNRVb8F7k9yWndoAzB7DYtW3AecmWRFd3/ZwIS+gL1Ys2UuqkNN9zBwWUM5C3gX8NMkO7pjH6uqGwesSUvD+4FruydF9wLvHrieQVTVtiSbgdsYvavtdiZ0mgWnVpCkRkzqJR1J0iwGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEb0CP8mzk/wgyT3d91WHaPeXJDu6ry19+pQkHZ1U1dH/cPJp4MGquirJR4FVVfWROdo9WlUn9KhTktRT38C/G3hDVe1Jshq4uapOm6OdgS9JA+sb+H+oqpO67QAPHdyf1W4/sAPYD1xVVf95iNvbCGwEeMYzjn3VyhUnH3VtkyT7DwxdAhw4+t+TsTqwBMailkANQC2B/5O/f/mfhi4BgF/8ZMXQJSwZj/DQA1X13LnOHTbwk/wQeP4cpz4OfHVmwCd5qKr+33X8JGuqaneSFwFbgQ1V9cun6/dZf7emXn3GPz9tbQsuGbb/zjH7hr9T5bEnhi4BgPrzY0OXAI89PnQFABx4Yvj/k+/+atvQJQDwxhf8w9AlLBk/rM23VtXUXOeWH+6Hq+qcQ51L8rskq2dc0vn9IW5jd/f93iQ3A2cATxv4kqTx6vu2zC3Axd32xcC3ZjdIsirJcd32ycBZwJ09+5UkHaG+gX8VcG6Se4Bzun2STCXZ1LV5KbA9yY+BmxhdwzfwJWmRHfaSztOpqn3AhjmObwcu7bZ/BLy8Tz+SpP78pK0kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEWMJ/CRvSnJ3kp3dvPizzx+X5Pru/LYk68bRryRp/noHfpJlwBeB84D1wEVJ1s9q9h5GUye/GPgc8Km+/UqSjsw4nuG/BthZVfdW1ZPAN4ALZrW5APhqt70Z2NDNny9JWiTjCPw1wP0z9nd1x+ZsU1X7gYeB58y+oSQbk2xPsv3Jp4afA16SJsmSetG2qq6uqqmqmjr2mJVDlyNJE2Ucgb8bOGXG/tru2JxtkiwHTgT2jaFvSdI8jSPwbwFekuTUJMcCFzJaGGWmmQulvB3YWn0W05UkHbFe8+HD6Jp8ksuA7wHLgC9X1R1JPglsr6otwDXA15LsBB5k9KAgSVpEvQMfoKpuBG6cdewTM7YfB94xjr4kSUdnSb1oK0laOAa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJasRizYd/SZK9SXZ0X5eOo19J0v
z1/uDVjPnwz2U0U+YtSbZU1Z2zml5fVZf17U+SdHQWaz58SdLAFms+fIC3JflJks1JTpnjvCRpAY1lLp15+DZwXVU9keS9jFa/Ont2oyQbgY3d7qNb/+tf7+7Z78nAAz1vY1I4FtMci2m9xmLZ6jFW0svOcdzIpPxevPBQJ9J3luIkrwX+rare2O1fCVBV/36I9suAB6vqxF4dz6+27VU1tdD9/C1wLKY5FtMci2ktjMWizIefZObzgPOBu8bQryTpCCzWfPiXJzkf2M9oPvxL+vYrSToyizUf/pXAlePo6whdPUCfS5VjMc2xmOZYTJv4seh9DV+S9LfBqRUkqRETG/iHm+6hFUlOSXJTkjuT3JHkiqFrGlKSZUluT/KdoWsZWpKTus/F/DzJXd077pqU5IPd/eNnSa5LcvzQNS2EiQz8GdM9nAesBy5Ksn7YqgazH/hQVa0HzgTe1/BYAFyB7xI76AvAd6vqdOAVNDouSdYAlwNTVfUyRm8+uXDYqhbGRAY+TvfwV1W1p6pu67YfYXSnnuuT0BMvyVrgLcCmoWsZWpITgdcD1wBU1ZNV9YdhqxrUcuCZSZYDK4DfDFzPgpjUwJ/vdA9NSbIOOAPYNmwlg/k88GHgwNCFLAGnAnuBr3SXuDYlWTl0UUOoqt3AZ4D7gD3Aw1X1/WGrWhiTGviaJckJwDeBD1TVH4euZ7EleSvw+6q6dehalojlwCuBL1XVGcCfgCZf60qyitEVgFOBFwArk7xz2KoWxqQG/m5g5gRta7tjTUpyDKOwv7aqbhi6noGcBZyf5NeMLvGdneTrw5Y0qF3Arqo6+NfeZkYPAC06B/hVVe2tqqeAG4DXDVzTgpjUwD/sdA+tSBJG12nvqqrPDl3PUKrqyqpaW1XrGP0+bK2qiXwWNx9V9Vvg/iSndYc2ALPXsGjFfcCZSVZ095cNTOgL2Is1W+aiOtR0DwOXNZSzgHcBP02yozv2se7T0Wrb+4FruydF9wLvHrieQVTVtiSbgdsYvavtdib0U7d+0laSGjGpl3QkSbMY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNeL/AE0ITwkB+bpJAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1vo9-V1al-mq"
      },
      "source": [
        "As before, these correspond to the 7th row of the full joint distribution shown in the previous section."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cSr_Jeuskoit"
      },
      "source": [
        "### Sampling differentiable continuous relaxations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ESIlao02cy6s"
      },
      "source": [
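        "To train the gadgets, we additionally provide a method `sample_relaxed`, which continuously relaxes the Gumbel-max operations inside each gadget into Gumbel-softmax operations. The default temperature is specified when initializing the gadget and controls a bias-variance tradeoff: lower temperatures give relaxed samples closer to one-hot vectors (less gradient bias, but higher gradient variance), while higher temperatures give smoother samples (lower gradient variance, but more bias).\n",
        "\n",
        "Below, we compare 15 discrete samples (drawn as one-hot rows) with their continuously relaxed counterparts, where each row is a new sample and each discrete sample shares its random key with its relaxation. The first two panels show gadget 1 and the last two show gadget 2. Because of the shared noise, the position of each discrete sample always coincides with the argmax of the corresponding relaxed sample."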
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 247
        },
        "id": "WkzzWMELbz8z",
        "outputId": "7f97e624-f825-4b34-bc3e-04038071bbd5"
      },
      "source": [
        "def draw_relaxed_from_p(key):\n",
        "  k1, k2 = jax.random.split(key)\n",
        "  # Each discrete sample and its relaxation share a key, and hence the same noise.\n",
        "  # Discrete samples are converted to one-hot vectors so they can be plotted alongside.\n",
        "  return (\n",
        "      jnp.zeros([10]).at[gadget_1_at_init.sample(p_logits, k1)].set(1),\n",
        "      gadget_1_at_init.sample_relaxed(p_logits, k1),\n",
        "      jnp.zeros([10]).at[gadget_2_at_init.sample(p_logits, k2)].set(1),\n",
        "      gadget_2_at_init.sample_relaxed(p_logits, k2),\n",
        "  )\n",
        "\n",
        "# Draw 15 independent samples, one per row of each panel.\n",
        "g1_samples, g1_relaxed_samples, g2_samples, g2_relaxed_samples = jax.vmap(draw_relaxed_from_p)(jax.random.split(jax.random.PRNGKey(3), 15))\n",
        "\n",
        "# Panels, left to right: gadget 1 discrete, gadget 1 relaxed, gadget 2 discrete, gadget 2 relaxed.\n",
        "_, axs = plt.subplots(ncols=4, figsize=(10,4))\n",
        "axs[0].imshow(g1_samples, vmin=0)\n",
        "axs[1].imshow(g1_relaxed_samples, vmin=0)\n",
        "axs[2].imshow(g2_samples, vmin=0)\n",
        "axs[3].imshow(g2_relaxed_samples, vmin=0)"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b862cab10>"
            ]
          },
          "metadata": {},
          "execution_count": 26
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlAAAADVCAYAAACYAM4HAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZZklEQVR4nO3de4yc5XXH8d+Z2Zt3bWzHF2xsF3MnQEmAlUMuaiOggZAIojRSoSKXBslN27SkjRSRRgpSpKatEiVtkyiRm1BTFUFakjQkgTqIhCAUMBhjm4uNAWNjG2MbG+P7XmZO/9hx5Di73vfM+8zs7Lzfj2R5PXP2fZ/Z/e3D4d2ZOebuAgAAQHaliV4AAADAZEMDBQAAEEQDBQAAEEQDBQAAEEQDBQAAEEQDBQAAENTRzJN1Wbf3qK+Zp0QC5158OHPt5q1Den1vxRq4HHV293l378zM9eUjw7ETDFeCK5K8Wg1/ToR1Bn9Ug4+hnrczsVLs2zwwd0rm2qF9e1U5fKhhOWIvGl/k516SNq7rbdBK6nNUhzToAw3di7qsx3sskKPoz1kdqzeLXRdp9N7VjBxZZ2fwM7J/H44MH9Bg9cio34mmNlA96tM77MpmnhIJrFixJnPtkqu3NnAlI7p7Z+rt770lc/20p3eFju9734guSdVDR8KfE1GeNzdUX92zN1TvlXjTWOruDtW/vPSizLVbln0tupwQ9qLxRX7uJenq097eoJXUZ6U/2PBz9FifLu+8JnO9Dw+Fjm/lcnRJsinZ/0dFkqoHDwZPEGvQVqx4MlR/9cLLQvWS1DFvXuwTAvvdr3d/f8z7cv0Kz8yuMbPnzexFM7s1z7FQXOQIeZEhpECOEFF3A2VmZUnfkvR+SRdIutHMLki1MBQDOUJeZAgpkCNE5bkCtUTSi+6+yd0HJd0t6fo0y0KBkCPkRYaQAjlCSJ4GaoGk45/wsq12GxBBjpAXGUIK5AghDX8SuZktlbRUknrUWq/SwORxfI66p8yY4NVgMmIvQgrkCMfkuQK1XdKi4/69sHbbb3H3Ze7e7+79nYq9ageFEM5RRzcvP8dvYS9CCvEcWU/TFofWk6eBekLSOWZ2hpl1SbpB0r1ploUCIUfIiwwhBXKEkLp/hefuw2b2aUkrJJUl3e7uzyZbGQqBHCEvMoQUyBGicj0Hyt3vk3RforWgoMgR8iJDSIEcIYJZeAAAAEFNHeUCpFA+WtHUjYFxKwODoeNbT/yJoeXg5/gpU2P1u2OjWUrR0S/TYuMfJMnePBSqn//rgcy1rx6Mz+ZDWtHRLCtejY1+qeccrcbKJZWmT8teHxx/pOgMTEnVXa+H6jsWxt6poToztndde9VZofp9N80K1UvSrEd3huoPXjA7c23loa4x7+MKFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBANFAAAQBCz8BosOh+qFWdDRda00fc0cCU1Q8PSa4F5T72xOW/VAweDC5IsOLPKyuVQ/ZY/f2uo/vS7tobqS7sCswWPmRKb/zdlw2uZa0tHh6KrQWLtsHc1mg9XVNkTm1MZEZ1TJ0k+GJv9qWo1VF7aeyBUv/ljp4fqp+yMz8H013aH6jvOeEvmWjvJl4crUAAAAEE0UAAAAEF1N1BmtsjMfmlmz5nZs2Z2S8qFoRjIEfIiQ0iBHCEqz3OghiV91t1Xm9k0SU+a2QPu/lyitaEYyBHyIkNIgRwhpO4rUO6+w91X1z4+IGm9pPgz3lBo5Ah5kSGkQI4QleRVeGa2WNIlklaOct9SSUslqUe9KU6HNpU5R6WpTV0XJg/2IqRAjpBF7ieRm9lUST+Q9Bl333/i/e6+zN373b2/U915T4c2FclRl8VePo9iYC9CCuQIWeVqoMysUyNBu9Pdf5hmSSgacoS8yBBSIEeIyP
MqPJP0PUnr3f1r6ZaEIiFHyIsMIQVyhKg8V6DeLemjkq4wszW1P9cmWheKgxwhLzKEFMgRQup+Erm7PyLJEq4FBUSOkBcZQgrkCFHMwmuwRs+His6rkgo4s8pie2L1yNH4KQaC++6hI6HyhV8PzhicPStU/vLNZ8aOL2nx3TtC9Yffdlrm2ur+zuhygOYzk3V1ZS73oeHQ4QcXz4muSIOBnzNJmvLA2lB9dNZeeckpofpZH94YqpckdcX2i55N2ffT0sDY3zNGuQAAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAATRQAEAAAQxTHiSa8Zg4MjA4iVXH27gSo5xqVLJXF2ZMyN09Hp+KPyUvlj9lu2h+tK0qbHjD8eGli7+9oZQvSQpOBh173XzM9cOPxEczhx07sWHtWJF9lwXbgC3ivmYo4ZO7dXWT16WuX7RV1eFjt+59qXoklQeGIh9QtVjxz//7FD9aX+yKVRfedfvh+olqfTUC6H6rR/OvhcN3jH2oGKuQAEAAATlbqDMrGxmT5nZT1MsCMVEjpAXGUIK5AhZpbgCdYuk9QmOg2IjR8iLDCEFcoRMcjVQZrZQ0gckfTfNclBE5Ah5kSGkQI4QkfcK1L9I+pyk6lgFZrbUzFaZ2aohBZ/chqII5WiwerR5K8NkEcrQ7j3ZX4SAQgnlaPjwoeatDC2n7gbKzD4oaZe7P3myOndf5u797t7fqe56T4c2VU+Ouko9TVodJoN6MjRnVrlJq8NkUU+OOnpjr75Fe8lzBerdkq4zs82S7pZ0hZn9V5JVoUjIEfIiQ0iBHCGk7gbK3T/v7gvdfbGkGyT9wt1vSrYyFAI5Ql5kCCmQI0TxPlAAAABBSd6J3N0fkvRQimOhuMgR8iJDSIEcIQuuQAEAAAQxCw+Tjleqqhw4kLm+tD42T2o4OktKUsfw3NgnLF4YKq++tCVU75UxX4U9qvLc2aF6Saoc3BOq777i9cy1pR/H5uxFbVzXy6y3SajV5nJ27jykRV95PHN9ed6poeNX5s2MLkn2zIux+nNOD9VXX9gcqt/z0eyzAiVp7i9fDdVLkubMCpV37cs+/89OshVxBQoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIBgoAACCIWXhBkVlMktpi3lbkMWz02Hy0upjJOjqzl0+ZEjp8qVyOrkiDZ88P1R9c1BOqn94T+1EtvbIrVF/PzK2yWah+1h9nn+dXHhiMLmfSK+LeEtVye5FMsuzXIQbOic3CK/9qbXRB2v537wjVz1t5JFTftX9OqH7ufZtC9Qcuj83mk6Spj8Tmnc7+98cy175UPTTmfVyBAgAACMrVQJnZDDO7x8w2mNl6M3tnqoWhOMgR8iJDSIEcISLvr/D+VdL/uftHzKxLUm+CNaF4yBHyIkNIgRwhs7obKDObLukPJH1Cktx9UFLxnriAXMgR8iJDSIEcISrPr/DOkLRb0n+Y2VNm9l0z60u0LhQHOUJeZAgpkCOE5GmgOiRdKunb7n6JpEOSbj2xyMyWmtkqM1s1pIEcp0ObiufIjzZ7jWht7EVIgb0IIXkaqG2Strn7ytq/79FI+H6Luy9z93537+9Ud47ToU3Fc2SxtwBA22MvQgrsRQipu4Fy99ckbTWz82o3XSnpuSSrQmGQI+RFhpACOUJU3lfh/bWkO2uvVtgk6c/yLwkFRI6QFxlCCuQImeVqoNx9jaT+RGtBQZEj5EWGkAI5QgTvRA4AABDELLygRs+fYh5WBu7yoexvz1J5o/Fv5VJ6JPZ9OyV4fA/W39eEHA0HZ+FZaMZg9BFPfkX8WS7afld+6KnYJwTm7B2z4BtPhurLc2Oz7VSphMp9+rRQfe/PYpmQpL9avy5U/82LsufIjo69z3EFCgAAIIgGCgAAIIgGCg
AAIIgGCgAAIIgGCgAAIIgGCgAAIIgGCgAAIIgGCgAAIIgGCgAAIIgGCgAAIIgGCgAAIIhZeA1WtFlPzWAdZZVnzspc74ePhI7vQ8PRJckuPDv2CRs2hcr33HBJqP4D71wUqi9d1Beql6TSvgOh+sqcGdmLN/wquBqMpxX3okm/37nLA7PhSr29ocNXLzwzuiJp7cbYOfa+Earf/acXh+rnPL4vVB/eSyXd9s+Xhern9j2fvXho7OtMXIECAAAIytVAmdnfmtmzZvaMmd1lZj2pFobiIEfIiwwhBXKEiLobKDNbIOlvJPW7+0WSypJuSLUwFAM5Ql5kCCmQI0Tl/RVeh6QpZtYhqVfSq/mXhAIiR8iLDCEFcoTM6m6g3H27pK9KekXSDklvuvvPUy0MxUCOkBcZQgrkCFF5foU3U9L1ks6QdJqkPjO7aZS6pWa2ysxWDWmg/pWiLdWTo8Hq0WYvEy2MvQgpkCNE5fkV3lWSXnb33e4+JOmHkt51YpG7L3P3fnfv71R3jtOhTYVz1FXieZ34LexFSIEcISRPA/WKpMvNrNfMTNKVktanWRYKhBwhLzKEFMgRQvI8B2qlpHskrZb0dO1YyxKtCwVBjpAXGUIK5AhRud6J3N1vk3RborWgoMgR8iJDSIEcIYJ3IgcAAAhiFl5QK86TKhqvVFTd92bm+vKC+aHjV2dMjS5J2rg5VL77pthsu9JQqFzVnbtjnxCtl1QZjs0MLAdmhlkd8whxco3ei6J7ozT590crlVTqCTyRvBS7ZlHeFH8bKps7J1Q/vH1HqH7ur3aF6qsvbw3Vf2DNzlC9JP3sbbHHrBnTw+cYDVegAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgmigAAAAgpo6C+/ciw9rxYrs85JacU5SK66paAbn9emVm5dkrj/9x3tDx/fnXoouSaXFC0P1s25/PHb8yLwtSfZ7C0L1/sr2UL0k2ZQpofrqrBnZi99gTCdan7vLAzMeS12doeNbR/zn4MClp4XqpwVnWmowNpizvGBeqP4b914Wqpekc07dEvuEajV7rdmYd3EFCgAAIIgGCgAAIGjcBsrMbjezXWb2zHG3vcXMHjCzF2p/z2zsMjHZkSOkQI6QFxlCKlmuQC2XdM0Jt90q6UF3P0fSg7V/AyezXOQI+S0XOUI+y0WGkMC4DZS7PyzpxGfhXi/pjtrHd0j6UOJ1oc2QI6RAjpAXGUIq9T4H6lR331H7+DVJp45VaGZLzWyVma3avSf7qxVQCHXlqHLoUHNWh8kiU46Oz9CQBpq3OkwGde1FQ360OatDS8r9JHJ3d0l+kvuXuXu/u/fPmVXOezq0qUiOyn19TVwZJpOT5ej4DHUq9rYQKI7IXtRpPU1cGVpNvQ3UTjObL0m1v3elWxIKhBwhBXKEvMgQwuptoO6V9PHaxx+X9OM0y0HBkCOkQI6QFxlCWJa3MbhL0qOSzjOzbWZ2s6R/kvRHZvaCpKtq/wbGRI6QAjlCXmQIqYz7PvHufuMYd12ZeC1oY+QIKZAj5EWGkEpTB05tXNdbuFlyK17NPvtPYtZeFl07DmnRlx/NXO8dsflTkdlWv7F7T6jcOmM/ehv/4eJQ/eKfxOZVdW1/LVQvSRqKnaMyLfsTt7009vwp1KfRe1ER967q9F4duiL74+77yZOh45fK8Rde9f5sdai+cslbQ/UDX94fqu9+f2zO5tn/uC9UL0nDBw6E6i3w3wQfHvu/B4xyAQAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACKKBAgAACGrqMOEiaocBm5EhpEuuPtzAldSn1DclVO9Dw+Fz2MwZsXOcPz1Uf+4Xnw3VKzgQed+HYsOKJemUu58I1b9xfm/m2srz/L9datG9iEHo47Oqq+NINfsnBI
cDW1dXcEVxpf1HQvWdt8X2uv0fmR+qn/nQy6F6STpyxfmh+r4H12euteGxB5uzSwEAAASN20CZ2e1mtsvMnjnutq+Y2QYzW2dmPzKzWEuKwiFHyIsMIQVyhFSyXIFaLumaE257QNJF7n6xpI2SPp94XWg/y0WOkM9ykSHkt1zkCAmM20C5+8OS9p5w28/d/dgTRR6TtLABa0MbIUfIiwwhBXKEVFI8B+qTku4f604zW2pmq8xs1ZAGEpwObYocIS8yhBSy52jwUBOXhVaTq4Eysy9IGpZ051g17r7M3fvdvb9T3XlOhzZFjpAXGUIK4Rx19TVvcWg5db+NgZl9QtIHJV3p7p5sRSgUcoS8yBBSIEeIqquBMrNrJH1O0h+6e+u98Q8mBXKEvMgQUiBHqEeWtzG4S9Kjks4zs21mdrOkb0qaJukBM1tjZt9p8DoxyZEj5EWGkAI5QirjXoFy9xtHufl7DVgL2hg5Ql5kCCmQI6TCO5EDAAAENXUW3rkXH9aKFdnnKxVxtlIrinwfNvqeBq5kxOBpfdryF5dnrj/r9u2h41e3vRpdkipvmRqqLz+/NVRfPRp72X10/t/Mnz4Xqpck74m9km3Wuv2Za8tHYrP8otiLxlfExxw11GfaeVln5vrFv479zAwtPjW6JPlZ80L1gz2x+Xzdj8Tmcs7YFZuFN7xzV6hekqY+GnvOv1cD8wtPgitQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQTRQAAAAQU2dhbdxXW9ovtKKV7PPqpKY3VQUXTsOafGXnshc/+Gnt4WO/98XLoguSeVXYvObKvv2xU7gsVlPlf3BWXIenw1V6o7N9Sq9lP37YAND0eWERPciNMZk3+O7Xx/UGcs3Z673RbG5cB0bXgmuSLJTYnM5qzOnxY7f1RWq9207YsfvvyhUL0lbroo9ht+7P7D/bhj78XIFCgAAIGjcBsrMbjezXWb2zCj3fdbM3MxmN2Z5aBfkCCmQI+RFhpBKlitQyyVdc+KNZrZI0vskxa8xooiWixwhv+UiR8hnucgQEhi3gXL3hyXtHeWur0v6nKTYkzNQSOQIKZAj5EWGkEpdz4Eys+slbXf3tYnXgwIhR0iBHCEvMoR6hF+FZ2a9kv5eI5c6s9QvlbRUknrUGz0d2hQ5QgqRHJEhjCbXXlSOvfoL7aWeK1BnSTpD0loz2yxpoaTVZjZvtGJ3X+bu/e7e36nYy57R1urPkZEj/EbmHLEXYQx170VdpSlNXCZaTfgKlLs/LWnusX/XAtfv7q8nXBfaHDlCCuQIeZEh1CvL2xjcJelRSeeZ2TYzu7nxy0K7IUdIgRwhLzKEVMa9AuXuN45z/+Jkq0HbIkdIgRwhLzKEVHgncgAAgCDz4IytXCcz2y1pyyh3zZY0Eb9vnqjzTuS5G33e0919TgOPT44m/rzNOHdDc0SGWuLckzpDEjlqgfM249xj5qipDdRYzGyVu/cX5bwTee6JfMyNVrSvaRHz22h8L9v/vM1QtK9pEfMr8Ss8AACAMBooAACAoFZpoJYV7LwTee6JfMyNVrSvaRHz22h8L9v/vM1QtK9pEfPbGs+BAgAAmExa5QoUAADApNHUBsrMrjGz583sRTO7dZT7u83s+7X7V5rZ4gTnXGRmvzSz58zsWTO7ZZSa95rZm2a2pvbni3nPe9yxN5vZ07XjrhrlfjOzf6s95nVmdmmCc5533GNZY2b7zewzJ9Q07DE30kRkqHbcCcsRGUqPvYgc5cVe1JwM1Y7bmjly96b8kVSW9JKkMyV1SVor6YITav5S0ndqH98g6fsJzjtf0qW1j6dJ2jjKed8r6acNetybJc0+yf3XSrpfkkm6XNLKBnzdX9PIe1k05TG3W4YmOkdkqD1yxF7UPjliL5qYDLVajpp5BWqJpBfdfZO7D0q6W9L1J9RcL+mO2sf3SLrSzCzPSd19h7
uvrn18QNJ6SQvyHDOx6yX9p494TNIMM5uf8PhXSnrJ3Ud7s7fJZkIyJLV8jshQDHvR6MhRduxFo2t0hqQWylEzG6gFkrYe9+9t+t1v+m9q3H1Y0puSZqVaQO0S6iWSVo5y9zvNbK2Z3W9mF6Y6pySX9HMze9LMlo5yf5avSx43SLprjPsa9ZgbZcIzJE1IjshQWhOeI/ai3zHZcjThGZIKuRdJLZSjcYcJtwszmyrpB5I+4+77T7h7tUYuBx40s2sl/a+kcxKd+j3uvt3M5kp6wMw2uPvDiY59UmbWJek6SZ8f5e5GPua2NUE5IkNthL3od5CjOhRtL5JaL0fNvAK1XdKi4/69sHbbqDVm1iFpuqQ9eU9sZp0aCdqd7v7DE+939/3ufrD28X2SOs1sdt7z1o63vfb3Lkk/0sil3+Nl+brU6/2SVrv7zlHW1bDH3EATlqHa8SYkR2QoOfYicpQXe1HzMyS1WI6a2UA9IekcMzuj1kXeIOneE2rulfTx2scfkfQLd8/1RlW13zl/T9J6d//aGDXzjv1u2syWaOTrkmKz7DOzacc+lvQ+Sc+cUHavpI/VXr1wuaQ33X1H3nPX3KgxLnU26jE32IRkSJq4HJGhhmAvIkd5sRc1P0NSq+XIm/iMdY08Q3+jRl698IXabV+SdF3t4x5J/yPpRUmPSzozwTnfo5Hf266TtKb251pJn5L0qVrNpyU9q5FXUjwm6V2JHu+ZtWOurR3/2GM+/twm6Vu1r8nTkvoTnbtPI+GZftxtDX/M7ZihicwRGWqfHLEXtVeO2Iual6FWzRHvRA4AABDEO5EDAAAE0UABAAAE0UABAAAE0UABAAAE0UABAAAE0UABAAAE0UABAAAE0UABAAAE/T/R0Y5Gz2hrpQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 720x288 with 4 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
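    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The gadgets implement their own relaxed sampling, so as a standalone illustration of the temperature tradeoff described above, the cell below is a minimal sketch using plain Gumbel-max and Gumbel-softmax on a single categorical distribution. It does not call the gadgets or reproduce their internals; the helpers `gumbel_max_and_softmax` and `run_temperature_sketch` and the distribution `example_logits` are hypothetical names introduced only for this sketch. For each temperature, it prints how often the argmax of the relaxed sample agrees with the hard sample (with shared noise this should be 1.0, up to floating-point ties), and the average largest entry of the relaxed sample, which approaches 1 (a one-hot vector) as the temperature decreases."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Minimal sketch, not the gadget API: Gumbel-max vs. Gumbel-softmax with shared noise.\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def gumbel_max_and_softmax(logits, key, temperature):\n",
        "  # One draw of Gumbel noise is shared by the hard and the relaxed sample.\n",
        "  gumbels = jax.random.gumbel(key, logits.shape)\n",
        "  hard = jnp.argmax(logits + gumbels)\n",
        "  soft = jax.nn.softmax((logits + gumbels) / temperature)\n",
        "  return hard, soft\n",
        "\n",
        "def run_temperature_sketch():\n",
        "  example_logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))\n",
        "  sketch_keys = jax.random.split(jax.random.PRNGKey(0), 1000)\n",
        "  for temperature in [0.1, 1.0, 10.0]:\n",
        "    hard, soft = jax.vmap(gumbel_max_and_softmax, in_axes=(None, 0, None))(\n",
        "        example_logits, sketch_keys, temperature)\n",
        "    # Softmax preserves the ordering of its inputs, so the argmaxes should agree.\n",
        "    agreement = jnp.mean(jnp.argmax(soft, axis=-1) == hard)\n",
        "    # Lower temperatures push relaxed samples toward one-hot vectors.\n",
        "    peakedness = jnp.mean(jnp.max(soft, axis=-1))\n",
        "    print(temperature, float(agreement), float(peakedness))\n",
        "\n",
        "run_temperature_sketch()"
      ],
      "execution_count": null,
      "outputs": []
    },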
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CAwhvZ1Wihls"
      },
      "source": [
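        "If we simultaneously draw relaxed samples from $p$ and $q$, again using the same exogenous noise, the outer product of the two relaxed vectors gives a soft, differentiable estimate of the implicit coupling between them. Below, we plot 32 such soft joint samples on a $4 \\times 8$ grid (shown only for gadget 2)."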
      ]
    },
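    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the construction explicit outside the gadgets, the cell below sketches the same idea for the plain Gumbel-max mechanism rather than a learned gadget: relaxed samples from $p$ and $q$ share one draw of Gumbel noise, and their outer product is a soft joint assignment whose entries sum to 1. The names `soft_coupling` and `run_coupling_sketch`, the temperature of 0.5, and the three-outcome distributions are hypothetical and chosen only for illustration. Averaging many soft pairs gives a differentiable estimate of the joint; its row and column sums are the mean relaxed samples, which approximate $p$ and $q$ but differ from them by the relaxation bias at any positive temperature."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Minimal sketch, not the gadget API: a soft coupling from two Gumbel-softmax samples with shared noise.\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def soft_coupling(logits_p, logits_q, key, temperature=0.5):\n",
        "  # The same Gumbel noise drives the relaxed samples from both distributions.\n",
        "  gumbels = jax.random.gumbel(key, logits_p.shape)\n",
        "  soft_x = jax.nn.softmax((logits_p + gumbels) / temperature)\n",
        "  soft_y = jax.nn.softmax((logits_q + gumbels) / temperature)\n",
        "  # Outer product: entries sum to 1, with row sums soft_x and column sums soft_y.\n",
        "  return soft_x[:, None] * soft_y[None, :]\n",
        "\n",
        "def run_coupling_sketch():\n",
        "  example_p = jnp.log(jnp.array([0.5, 0.3, 0.2]))\n",
        "  example_q = jnp.log(jnp.array([0.2, 0.3, 0.5]))\n",
        "  sketch_keys = jax.random.split(jax.random.PRNGKey(0), 5000)\n",
        "  pairs = jax.vmap(soft_coupling, in_axes=(None, None, 0))(example_p, example_q, sketch_keys)\n",
        "  # Averaging soft pairs gives a differentiable Monte Carlo estimate of the joint.\n",
        "  estimated_joint = jnp.mean(pairs, axis=0)\n",
        "  print(estimated_joint)\n",
        "  # The marginals only approximate p and q while the temperature is above 0.\n",
        "  print(estimated_joint.sum(axis=1), estimated_joint.sum(axis=0))\n",
        "\n",
        "run_coupling_sketch()"
      ],
      "execution_count": null,
      "outputs": []
    },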
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 483
        },
        "id": "1uLsfU7VhyTS",
        "outputId": "34032f18-1eea-4f76-9692-51aee26095d9"
      },
      "source": [
        "def gen_soft_pair(key):\n",
        "  # Reusing the same key for p and q shares the exogenous noise between them.\n",
        "  soft_x = gadget_2_at_init.sample_relaxed(p_logits, key)\n",
        "  soft_y = gadget_2_at_init.sample_relaxed(q_logits, key)\n",
        "  # Outer product of the two relaxed samples: a soft joint assignment.\n",
        "  return soft_x[:, None] * soft_y[None, :]\n",
        "\n",
        "# Draw 4 * 8 = 32 soft pairs and show them on a 4 x 8 grid.\n",
        "soft_pairs = jax.vmap(gen_soft_pair)(jax.random.split(jax.random.PRNGKey(1), 4*8))\n",
        "_, axs = plt.subplots(nrows=4, ncols=8, figsize=(16,8))\n",
        "for i in range(4):\n",
        "  for j in range(8):\n",
        "    axs[i, j].imshow(soft_pairs[np.ravel_multi_index((i,j), (4,8))])"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAA6EAAAHSCAYAAAAdcs1sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeZQdV30v+u+vztCnZ3WrZcsabMm2LFsmWLblEYIxxnggwTdZEMx7GW9uRBhy4SZZ75H73iLrkmTdkPsykEUC+AZuIEAMVjCYYDTdgEd50ORBkmVL1tgauyW1ejp9zqn6vT+65UhNd+2fpN11ukrfD0vLPfzYtbu+VbvOPkNtUVUQERERERERJSGodweIiIiIiIjowsFJKBERERERESWGk1AiIiIiIiJKDCehRERERERElBhOQomIiIiIiCgxnIQSERERERFRYvL12nBRSlqS5tgaCaxzZPcyMxpGxrb8kYJ792qtZmvM40o6/Tjeo6qzfLSVLzVrQ3NnfM2Q7W/UkYqhyLYjJJdzF1mPL3GXaLVqa0rc29RS0dRW/9BBbzl2deZ0wfxCbM1re22bCqruc00sWQOImhvcbdWMx0TVcBzWQlNb2mjoV8V2TJysHvWWY665WQsdjvNx2NZWZLg6FI7bckRo268mYjghjbTJnSP6h0xt+RxXTdfHQvz5eopWDBlZd6nH65Dk3WO0+brtcak5vzk2aAnxOaae5dipw0qAPnNs78zp7Lnx59uhLbacwxlNzprcgG1ctTx+jNrd2xvdZtndluF6DADBkKX/tkHnZPUIz8cMmOx8NE1CReReAF8AkAPwD6r65+N+3wDgGwBuBNAL4MOqujuuzZI049bCvbHbDZobLd0DIvcIFw4M2tpSw0XPeMHLz5rtrAmP9pjaMk9WY/ToIbyOzQDQKiKf8ZFjQ3Mnrn3/p2O327npuKl/0eu7nDVqfDCba2tx1kizcTAyTFbDg4dMTUmDexDXaxbG/r6n7w1s37sS8JjjgvkFvLBqfux2f/4TH3X2HQCaut0P2nM7u01tDd16hbOmeMx2wS4cOOasiY72mtqK3ubuV37f0djfHy3vwba+pwCPORY6OjHvP/+X2O12bbaNX8Nd7uN+zqOx3XlLdLLfXWQ8ty3nkGkcB1BZ6s4x98Sm2N/36CG8rpsAjzlaro+5eZc4+w4AtT37nTUS2B4QquFai8g4RnfMdDfVP2BqS0dGTHVxpuL6WEIzbpG7zrtvZoHhyVfAnJGF5D0+2e7BVOQ4e24BX3rsstjtfv7am03967tnqbOm45l9prbCw/HXGAAYfM/1prZan3zD3dZtV5raalm/x13keKHgaHk3tp14Ekj6fKzDOXQhWKsrJjwonI8yRCQH4O8A3AdgCYCPiMiScWW/DeC4ql4J4K8BfP78uku+qSq2YxOW4p0AsAXMMZVUI7y253Fcv+j/BJhjaqlG2Nr3BJbN/EWAOaaWqmK7bsRS+XmAOaYWr4/ZwByzQTXC1uM/xbKuBwDmmGmW9yPeDGCHqr6pqhUADwN4YFzNAwC+Pvb1CgB3iXh8zxSdtz4cQyNa0CQtwOibY5hjCvUNdqOpoRNNpU6AOabWiephNOXb0ZRvB5hjanFczQbmmA3MMRtOVA6jKT+D18cLgGUSOhfA6e8N2D/2swlrVLUGoA+A+702lJgRDKOEM97ezBxTaKRyEg3FttN/xBxTaCQcRGOu9fQfMccUGh1Xz/jMFXNMIV4fs4E5ZsNIOIDG3BkfqWKOGZXo3XFFZLmIrBeR9VV1fwiapqfTc6yNGD9rS9PO6Tke7eXnG9Lq9BzDQZ6PacXrYzackSPO/7OqVB+n53jiGK+PacXzcXqzTEK7AZx+x5J5Yz+bsEZE8gDaMfpB4TOo6kOqukxVlxWkdG49pnPSgEaUccZtMb3kmG/gncaS1FBsw0jl5Ok/8pLjrJ
nGD+OTFw25ZgyHZ9ysx0uOOevNtsiL0XH1jBtx8fqYQlN1fSzAdjdR8mOqcpzRyetjkhpyLRgOz7gpGc/HjLJMQl8EsEhEFopIEcCDAB4bV/MYgN8Y+/qDAP5N1eM90+m8taEDwxjAsA4Co/fGZo4p1NY8B0MjvRgeOQ4wx9RqL1yMoVofhmonAeaYWhxXs4E5ZgNzzIb24sUYqp3AUK0PYI6Z5ry3tqrWROSTAFZhdImWr6nqFhH5HID1qvoYgK8C+CcR2QHgGEYPmFiSCxA4ltHQuRcZ/gQAhrvxB3tsS0JY1lSz3g5+6O3znDWlZ21vn9MBw63qY9agDJDDYr0Bm6KnAOBaAH/iI8fcSIT21+P/Buk94e47jMuvGJde0LIhI8tyA4DzVuKAfekYGNaOzB2efH/lACxpuwMbt34d8Jjjtu5ZuOn/+VhszUUvGG67DkCH3AtRRv2GJTsANG9036retA4igGjQvXRMZDy3c2+4l70Ih+K3tzi4AeuPPgp4zDGoAC2OmBqP2tYvDWqGpResS2MYzg/zOWRZ7iGyjROFE4a18WIe1wQQLMZSbNInAY85SkMRweXxS0Ls/G+2V0sv/z33x6SkaFub2DKuRsdtS3Id+MhiZ82cf7MtmRRufd1ZE7d2dA7A4uhGbAqfBjzmmDTL2uQAoBV/S9FZjh2vS7TE3HsmkBwW6/XYpH4f5xza2oLPv/222JqobHss1/rIi86amsflP5p+sN5UFxq2WfqRu+8AUPMwF1ysP4f1h/4F8Hk+intJoZV7bfvrnrmGpW+s+8FyPyXDGvOj2zSs0254TAsAQXubs0YNj6sAAJM8LDSNWKr6OIDHx/3ss6d9XQbwIVtPqF66ZA66cnOwNvzOq6r6ZwBzTKNZjQsxq3EhVu79G+aYYrNyczErNxdryt9ijinWJZegC5dgra5gjik2K5iDWcEcrKk+zBxTrEsuQZdcgrXRI8wxxTiuXhgSvTERERERERERXdg4CSUiIiIiIqLEcBJKREREREREieEklIiIiIiIiBLDSSgRERERERElhpNQIiIiIiIiSgwnoURERERERJQYTkKJiIiIiIgoMfn6bVqAfPzmtaFga0rVWRKI2Noy1Lm3NqraknPWNOaMzwPI9Hy+QIbKCF5+I7amVqnaGotCDz0aa6pcdheNjHjbnuUYBACtVpw14cFD59ubsyahovF4/P5XY45aq7lrIuv+Mmyz6t4eAGgYmepMbRn+RoT+jmersElxfGl834JK0dTW8MXusbD9xSZTWzDtL/d4CQDSYOi/8Xwcnt3srClarx3WC4OlqZERRG/siq254vcvNrVVO9rro0uj1HAOGff9nG9uc9ZEA4OmtizbNJ2z9WA9vixNFW3ntlrGTLWNX9JYchcNDZnasuwLydnGCfgb7lG5qBF7/+N1sTXzPv+8qa3cNVc6a3TvAVNb0cCAe3uLFpra0j37nTUyf46trW7DYxjrce/+E+3UPQ7cM/d6Y1s+B3xDW8bz0bQ541gY9h7zts3JTM+ZDREREREREWUSJ6FERERERESUGE5CiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixHASSkRERERERIlxTkJFZL6I/EREtorIFhH51AQ17xaRPhHZPPbvs1PTXTpXZR3ChvAnWBf+GACuZY7pVNZBrK/9G56tPg4wx9RijtlQ1iFsiH6KddFKgDmmVlmHsEGfwDpdBTDH1OK4mg3laBAvDq/CM4PfB5hjplnWCa0B+ANV3SgirQA2iMgaVd06ru4pVf0F/10kHwSCRcF1aJNOrA2/sw3AJ5hj+ggCXJVbijbpxJrqw8wxpZhjNggEi+Q6tEkH1kaPMMeUEggW4e2jOeoK5phSHFezQSBYXFyGttxMrB74OnPMMOckVFUPAjg49nW/iGwDMBfA+IPhrGgYIjrRF1sTeFxcOjIulqyRvwVo29e7FxwOrYtxR+e3UG0DimhA8dSCtxEAPzkC0N
DjqtAUewwWUUIRJejo4sbecswN19D68pHYmvDkSVtjoeFYNR7PUb9hpWrL9jA65riLbOe/Dg+f1/aKY/9T+D0fS0ciXP2l+DEl6LXlqI0Nzpro8FFTW1GlaqqzEMO+t2pc/6azJow5JhpQQgNKp44bbzkKBJKPvzxre4utrUOGReHF+MkcdbdlXQhdSu7jy5q1nufh1SCNaEDjqW+95QgRSKEYX1Is2JoqxrcDAL2/eLWpra4nu5010aH468Epe3/Hvc25f/G8qa1cR7uzRtrbJv1d09g/AMBOfzkWToaYvzr+8apaH6N1H3KWWK4vo4WG69VBW46WMTp3uMfWVnnEWSPB5GNJETkU0QKtjQA+z0cL42MA8uOsPhMqIgsAXA9gohHlNhF5SUR+LCLXTvL/Xy4i60VkfVXLZ91Z8qYI5pgF3nKshLYnaWhK+MuxxhzryF+OcD+IoynD62M2+MuR42o9+cuR4+q0Y3k7LgBARFoA/AuAT6vq+KfSNwK4TFUHROR+AN8HsGh8G6r6EICHAKAtmMmnG+qgpjUAuALArzHH9KqNPvXvLcf20mzmWAe+z8f2pjnMsQ6858hxtS54fcyGWlQBfObYzHG1Hnw/zmmTTuY4zZheCRWRAkYnoN9S1e+N/72qnlTVgbGvHwdQEJEurz2l8xZphJexDgCOMcf0ijTCy9GzAHNMtUgjvKzMMe04rmYDc8yGSENsOvQYwBxTLdIIL4fPAMwx0yx3xxUAXwWwTVX/apKa2WN1EJGbx9rt9dlROj+qiq1Yj2a0AsDhiWqY4/SnqtiqL6BZmGOajea4Hs1oA5hjanFczQbmmA2qilePrEJLsRNgjqmlqtgavoBm4fUx6yxvx30HgF8D8IqIbB772X8FcCkAqOqXAXwQwMdEpAZgGMCDqvx073TSh14cwl60oB0AloxlyRxTpg89OKR7mGPKjZ6PzDHtOK5mA3PMhhPlbhwY2IqWYhfAHFPrhPbgoO5Gi/J8zDrL3XGfBhB7SzxV/SKAL/rqFPk3Q7rwXnwQALBWV2xV1WXja5jj9DdDZuG9uQ8DANaG32GOKTVDuvBe+RAAYG30CHNMKY6r2cAcs6GjcR7uveIPAQArd/5/zDGlOoJZuDt4EACwpvowc8yws7o7LhEREREREdH54CSUiIiIiIiIEsNJKBERERERESXGvE7olIgcnyGu1rxtSl3b+vdCb9vESMVfW9OUFAsILpsXX3TUdsOysG/8MlATMH7uXApFQ42/wz8aHrYVivt5n1znDFtbR21lFuWLC9j2BxfH1iz+UsnUVtA34KyJjvaY2pKF8901Q7YF4dVwfEUDg6a2cpfMdrd1os/UFgyHvVkYIuiNb1CNfyMq7vFLa8Yx2uO4qpG/50512HbsJE0xeofIOFKp2toyXPsksOVjvo6a2nJvM/X3GVGF1hw5Gc8NNeQ987kjprYiwzU5Mh5f89a4xzmNQlNblvHXetz7JCMVBG/si62x/YVAeNJ9fYRxf5m2Zx3vDdv02ZbPh9peBTlbnceMLmR8JZSIiIiIiIgSw0koERERERERJYaTUCIiIiIiIkoMJ6FERERERESUGE5CiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixHASSkRERERERInJ123LqtBq/GLooeP3013tsGHh6JQvxh015DF8xczYmqaabVFfMSyErKGxrWLBXdNYMrWFnGHx4vKIqanA0C/M7DC1haO2MpMICIYltkQMi8uPtuVxFerQ0FZkPIesdRaW89bnfjgbQXyOzt+PEbHVpVowTZ+HVYVW4q9/2n3I1tY0XTg+OnbCWaO1agI9mWKOscJ6TQMMdYdsFwUdMVyvDMcNAAS7DzprrH+h65gfbczamkdX5ICHWmJL5F7bY4CRd/+cs6Zx815TW1HvMWdN7effbmqruGGHs6Z63RWmtgpb97iLAsPjKgAwPIz2SQq2aZGO1OE4zKBpegUmIiIiIiKiLDJNQk
Vkt4i8IiKbRWT9BL8XEflbEdkhIi+LyA3+u0rn62l9HOt0NQAsYY7p9cSOv8czu74KMMdUe3JwBZ4d+gHAHFON42o2MMdseKr6GNZVfwwwx1R7ouebeKb3OwBzzLSzeSX0TlVdqqrLJvjdfQAWjf1bDuBLPjpH/t2IOwBgK3NMt5vmfwRgjqm3rPEegDmmHsfVbGCO2XBj/j0Ac0y9mzo+ADDHTPP1dtwHAHxDRz0HYIaIXOKpbUoOc8wG5pgNzDEbmGM2MMdsYI7ZwBwzwDoJVQCrRWSDiCyf4PdzAew77fv9Yz+jaWYTngKAa5hjeokI1u/7DsAcU06wYXgNwBxTj+NqNjDHLBBsrP0UYI6pJgDWn/hXgDlmmnUS+k5VvQGjL39/QkTedS4bE5HlIrJeRNZXYbuTGPmzDHfiFnkvALwBXzlW3He0Jb9uvvRXcfvC3wI85hgOMsek3dx4L25r+kXAY46VcNhrH8ltSsZVXh8Txxyz4ab8Xbi1cA/gc1zt47iatJs7/gNu7/wQwPMx00yTUFXtHvvvEQCPArh5XEk3gPmnfT9v7Gfj23lIVZep6rICGs6tx3TOStJ46ssafOVYbJ6KrlKMUqH11Jfecsw1M8eklYK39rm3HIu5xvG/pik2JeMqr4+JY47ZUJKmU1/6G1fbOa4mrZR7a0kcno8Z5pyEikiziLSe+hrA+wC8Oq7sMQC/Pna3qlsB9Kmqe/EoSkyoNdT0rTXXAjDHVKpFFdTCt57NY44pVdMqz8cM4LiaDcwxG5hjNtS0ilr01pqxzDHDLKuyXgzgURlduDwP4NuqulJEfhcAVPXLAB4HcD+AHQCGAPzW1HSXztUIyngZ60Y/3QtcA+BPmWP6VGpD2NT9L6e+ZY4pVdEyNpd/cupb5phSHFezgTlmwwjKeKn29KlvmWNKVaJhbOpbeepb5phhoqp12XBLxzxdeuenYmsOPVg2tRWFOWfNFV+omdrKHRtw1oRv7jW1tWr/BmfN+z70m6a2cpvfcNZIwfKcArDq+Fc3THLL67PW3DVfl7z/v8TWdG46bmpLt7/prglDU1u59jZnjTQ3OWtGG3O/az3stj0BJw3ut4PokstNba154Y+95njNB+JznLVmj6ktHXJ/fibq7ze1leua6d5eteqsAQAdHHLWRCO2z4zkZsxwtzXk3h4ArCl/y1uObS1z9Za3fTS2Jn/0pKktbSq5a3Yaj4mae/zVyHYtkpx7vLfKzexw1tQOHTa1tVZX+MuxbZ4uW/aJ+O19+2umtu6/45edNdpke5tacNJ9TIcHDpnaeuOrS5w1l/2j7bYVxZ++5C4yHjc+z8f2/Cy9rf2XYmukxfZRCG1xvyW062u2Y3Xvnyx21jRtsJ3bj29a7ay5d+EtprZk0UJnzdBC97UdAJ5+7P/ydz62ztWbrv94bE3wjOEYBJCb2emsifps10etVpw1uQ73GAcAYZ/7umB5XAXYr+8Wa6oP+8tROvUWuSu2RgpFU1uWfU//brLro68lWoiIiIiIiIicOAklIiIiIiKixHASSkRERERERInhJJSIiIiIiIgSw0koERERERERJYaTUCIiIiIiIkoMJ6FERERERESUGE5CiYiIiIiIKDH5em04NzCClqd2xNZcseMiU1ui7kXOde8BU1tRxbAAbRSa2rrzt/6Ts6b08nZTW6aF7yX55xTyfSOYuTI+RzUuXOxz8V/LwstiXVDZsF+1VjM1paH72Am27DS15VN+sIauF4/F1kTHT5ja0qp7X1j2AwBEJ/oMbUWmtkzbNIwlABANDLqbqlVNbXkVCMLmQmxJrt+2GHdUcl8eJJcztQVDRhLYckQgtjqLfN0ugbFkpIqGnUdia675ysdNbS08usW9PesC7SMj7hrD+Q8Ac77n3mbjjoOmtmqWczuynds+aRgiPH48vsj1+7Nw5F22HBtq6501oXEsvO/yW501OlI2taWvvuasKW3xeP4bXXV5L9Z89x9ja+678n
ZTWz3vv8pZM+uJblNbYfchZ82Jexab2pqx9g1nzcl3X2lqq33dPndRYHy8utdW5kvQWDLVhR4fr17I+EooERERERERJYaTUCIiIiIiIkoMJ6FERERERESUGE5CiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixDgnoSKyWEQ2n/bvpIh8elzNu0Wk77Saz05dl+lcDOpJPBeuwnPhKgBYwhzTaTDqw7rhH2Ld8A8B5phag9qP56LVeC5aDTDH1BrUfjyna/CcrgGYY2rx+pgNHFezYaB6DM8c/CaeOfhNgDlmmnORNFXdDmApAIhIDkA3gEcnKH1KVX/Bb/fIl2Zpw625ewAAa8PvbAUwD8wxdZqDdtzW+IsAgNWD32COKdUsrbhV3gcAWBs9whxTqllacSvuBgCs1RXMMaV4fcwGjqvZ0FLoxDsu+VUAwMq9f8McM+xs3457F4CdqrpnKjpDiWkDc8wC5pgNzDEbmGM2MMdsYI7ZwBwzzPlK6DgPAvjnSX53m4i8BOAAgD9U1S1xDWkYIuo7GbsxGS6bOqWGmmh42NSWT40v7HTWhINDtsbU8FdqaGsL6ATwt5P87uxyLBVQu2pu7MYKu4+YOhUdPOwu0sjUVlBqcNZIsWBqC7mcsyQ8dtzUlOTd2wwu6jK1hV3+cpx55Un85qMrYzf3pd/7kKlbpUODzppgzwFTWyM3XOmsKRy3jRP5I+6Mwp5eU1tyzRXOmqD7qKktHPGX41ULe7D2W1+L3dyVP/1NU7fuunK7s2b3f15kaiu/371ftVo1tYXWZmeJRJarArD7w3OcNfM+f8jUFkKP42qlitr+7tjNLfgfx2zdGjJcY0RMbZmuQ0bNP9zkrKnVjMfENL0+AgCC+OuHGK4vo+0YMnqbe7wEgNzO/c6aaMA9jgNA5fZrnTX5n2w0tRU0GK7bjY2mtnDMX45bD83C0v/+8djNXVx+3tStrnXux0PRUdt1SA3nR8cLtvEr7B9w1rRvsLUV9RrGpsD8Gpjf89Fh6B2LTXUNP37xfDc17UnePUXU0DiuTjJEmyehIlIE8AEAfzTBrzcCuExVB0TkfgDfB/Azj05EZDmA5QBQQpN10+RRNDqJawfwyAS/PvscG9qnrrM0qWj0AZW3HLvmFKeuszQp3zleOvdsn1ckH3znyOtjfXi/PjLHuvB9PhZaO6auszQpjqvZdzZvx70PwEZV/ZmXq1T1pKoOjH39OICCiPzMyzmq+pCqLlPVZQVxP+NF/vXgEAAMecux4H5Vgvw7OrQL8JhjaycnL/XQU9kLeMxx1kzjqyrklfdxFbw+1gNzzIae6j7AY475Rj7OqYeesBvg+ZhpZzMJ/QgmeSuuiMwWGX0/j4jcPNau7f0ElKjD2AsAE75Xgjmmx8HB7QBzTL2D5R0Ac0y9w6MfV2KOKcfrYzYcrOwEmGPqHartBphjpple/hCRZgB3A/joaT/7XQBQ1S8D+CCAj4lIDcAwgAdVPX54hLwItYZjOAIAJ079jDmmTy2qond4D8AcU62mVfRW9gHMMdVCreHY6BP1zDHFeH3MhppW0VvtBphjqtW0it7wAMAcM800CVXVQQAzx/3sy6d9/UUAX/TbNfItJ3ncgQ9gra5465PEzDF98kEBd132cazc9VfMMcXyUsB7Zv1HrDryJeaYYjnJ447cL2Ft+B3mmGK8PmZDXgp4T8evY9Wx/8kcUywvBdzZ/CBWD36DOWbY2S7RQkRERERERHTOOAklIiIiIiKixHASSkRERERERImp47oMAjgWaRbDwsXmrVUqpjq1LHJuXPRamt1rEkl/v6mt0eXLpp+rLu/F6kf+Mbbmmq/EL/J8ysK/dy+8rOWyqS29eoGz5uQC25pRYYN7kfDOH203tYWLZjpLtn6q09bWx2xlFgeOd+L//d7/EVtz5d4eU1ty0r3IeThsy7Fh33H39oZHTG2pYTFurdZMbeWOnnDWRENDprZ8euPVFtx35e2xNVdWt5
ra2hO4j3tUXjG1VfN5v4ifuVn/uZv/P/Y7azQyLsbtkwDiuD4GHTNMTUXDw4btWZ+PNlyIjFkHzY3urQ3YLnxas523iROBFOIfZknRtkazFAvOmsPL2kxtXdznvsZIxX09BoDud7v7f9lPbceXtLa6i2YYaoBJ7ql6bsICMDgv/rgWy3gJIGouOWsCxzFz2kbd22txbw9wjzcAoE22tlBwH6tjN7addhq7bY/Jp+lDcq80NFz7zvPazldCiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixHASSkRERERERInhJJSIiIiIiIgSw0koERERERERJYaTUCIiIiIiIkoMJ6FERERERESUGE5CiYiIiIiIKDGiqvXZsMhRAHvG/bgLQE8duuNDmvp+marO8tEQc6wr5ji5NPWdOU4uTX1njpNLU9+Z4+TS1HfmOLk09Z05Ti5NfZ8wx7pNQiciIutVdVm9+3Eu0tx339K8L9Lcd9/SvC/S3Hff0rwv0tx339K8L9Lcd9/SvC/S3Hff0rwv0tx339K8L9Lc91P4dlwiIiIiIiJKDCehRERERERElJjpNgl9qN4dOA9p7rtvad4Xae67b2neF2nuu29p3hdp7rtvad4Xae67b2neF2nuu29p3hdp7rtvad4Xae47gGn2mVAiIiIiIiLKtun2SigRERERERFlGCehRERERERElJhpMQkVkXtFZLuI7BCRz9S7P2dLRHaLyCsisllE1te7P/XCHLOBOWYDc8wG5pgNzDEbmGM2MMfpoe6fCRWRHIDXAdwNYD+AFwF8RFW31rVjZ0FEdgNYpqppWTTWO+aYDcwxG5hjNjDHbGCO2cAcs4E5Th/T4ZXQmwHsUNU3VbUC4GEAD9S5T3T2mGM2MMdsYI7ZwByzgTlmA3PMBuY4TUyHSehcAPtO+37/2M/SRAGsFpENIrK83p2pE+aYDcwxG5hjNjDHbGCO2cAcs4E5ThP5encgI96pqt0ichGANSLymqo+We9O0VljjtnAHLOBOWYDc8wG5pgNzDEbMpHjdHgltBvA/NO+nzf2s9RQ1e6x/x4B8ChGX+q/0DDHbGCO2cAcs4E5ZgNzzAbmmA3McZqYDpPQFwEsEpGFIlIE8CCAx+rcJzMRaRaR1rxalncAACAASURBVFNfA3gfgFfr26u6YI7ZwByzgTlmA3PMBuaYDcwxG5jjNFH3t+Oqak1EPglgFYAcgK+p6pY6d+tsXAzgUREBRvfnt1V1ZX27lDzmmA3MMRuYYzYwx2xgjtnAHLOBOU4fdV+ihYiIiIiIiC4c0+HtuERERERERHSB4CSUiIiIiIiIElO3z4QWpUFLaK7X5i9o/Tjeo6qzfLRlylGMjVneGS62xiSfc2+uWjO1Zd2mieXt78bN9WvSORo7xrf4nxWv52OhWUulGbE1MjRiaqs6s9FZU+gdNrWlUWSqs5DGknt7w2VbWznD87DGw/lk1Ostx0KxWUuNHbE1wVDF1JY2FHx0adSQe79KwbY9rVYNbdkeomjRXWc97n3mWJSSlsQxrhrHSyk1OGu0bPsbE+fzGmo8Ib1eH4OSNgat8UXGHE1jYZN77AUADNnG3+lIjMfEST3mLceuzpwumB8/Pm05YttUYBh+c8eHTG1J4L4OWa+hpjGzFpraMm3zPB+vmkZ4EbkXwBcw+gHef1DVPx/3+wYA3wBwI4BeAB9W1d1xbZbQjFvkLlvvk2Q5MVL6ILtHD+F1bAaAVhH5jLccg/fGbldy7gkhAGjk3q/WtnJdnc6a2sFDprakwX3xt9KKe+Ry/Y090UFsDzcCvnN0nI/W/WD5G9N6Dvk0JedjaQZuuv7jsdvNb3rD1L9Dv3Kds2b2P71iaisaGHAXie2NOcHVi93be2mbqa1ci+OBJQCtxT9Z1VPrxmuVFwGfOTZ24Ibbfy92u40b9zj7DgC1K+e4i4wPJIINrzlrcnNmm9oKD7jH39xFtgeElcu6nDX5ja/H/n5KcpRm3N
pwX+x2dcQ2ccwtuNJZE76xy9QW1N+TQpaxXApFW1uB4UAM4x9AT8X1sTFoxW1tD8RuVyvuJ1UAIBp2TxzlbW8ztaUv2sZfk4Qf+wal+CcTe8IDeK3q93xcML+AF1bNjyvBtV+Mv36e0rrHfQ7N+O5GU1tBi/sFuWhg0NRWbvZF7rZO9Jnaivr7nTWStz1RuKb68IQXLOdVX0RyAP4OwH0AlgD4iIgsGVf22wCOq+qVAP4awOdNvaLEqCq2YxOW4p0AsAXMMZVUI7wWrsf1+TsA5phaPB+zQTXCtsrzuKF0F8AcU4s5ZgOvj9mgGmFb9QXcUHwPwBwzzfLU880Adqjqm6paAfAwgPFPCT0A4OtjX68AcJdYX2unRPThGBrRgiZpAUbfz8IcU6hPj6FJWpljyvF8zIa+qBdNQSuaRt+qxxxTijlmA6+P2dAX9Y7myPMx8yyT0LkA9p32/f6xn01Yo6o1AH0AZvroIPkxgmGUcMbnDJhjCo1gGA1oOv1HzDGFeD5mQ1mHMO4zf8wxhZhjNvD6mA1lDKEkzPFCkOiNiURkOYDlAFA6c6CgFGGO2cAcs+H0HBsa2uvcGzpXZ+TouLkUTV8cV7PhjBwD3kQzrU7P8dK5dbsXK03C8kpoN4DTP8k7b+xnE9aISB5AO0Y/KHwGVX1IVZep6rIC/N3shdwa0IgyzvhAPHNMoQY0YgRn3HGNOabQVJ2PxQIfLCWpJE0o6xk3jPBzPhaZY5KmLEdx38mZ/Jmq62NRjHerJS9KaEJZ/ec4a6btxpaUHMsk9EUAi0RkoYgUATwI4LFxNY8B+I2xrz8I4N9UefvL6aQNHRjGAIZHL7QC5phKbdKJIe3HsA4AzDG1eD5mQ1swE0NRP4aifoA5phZzzAZeH7OhLZiJIe3HUMQcs845CR17r/UnAawCsA3Ad1V1i4h8TkQ+MFb2VQAzRWQHgN8H8Jmp6jCdm0ACLMZSbMJTAHAtmGMqBRJgce5GbKw9ATDH1OL5mA2BBLi6eDM2ltcCzDG1mGM28PqYDYEEuLpwEzZW/jfAHDNN6vXEQZt0qnOdUJ83ujL+nbku9+eaw56fecV/Yob+S964sHfNtgaVxdrokQ2qusxHW23BTHWtg2ZdR8i1/hcAwLgQOq6IXwsKOIu1BGcYPp9lXHctGjSsEWZcoH314De85dgezNRbS/fH1vz4zedMbb3nN/+Ts6bx5X3OGsC+npWFZbH36tsWmtpyrTkIAMHFtjUOV775l/5yzHfpbS3x69mFhrW/ACA/5xJnTe3AQVNbXteXazasqTZoW1MNgeHtWcZz2+e42l68WG+f/ZHYmvDwUVNbwQzD54Qt6zMCCA8fcTfVZPscZDTkXsjdtZbgKdLe5qwJj9j2l9fro3Sqax1t67lh2RdRuWxqa9ryuFblWl3h73zMdemtLR+IL6oa1wk1rAubW3KVqa1wy3ZT3XRkXTt2TeXb3nJsuHS+XvJ/fyq2ZsEPDY9DARR73Y/ldJN7XWUACIrux7WRcR3aXFuLuy3D41AA0KphzXfjPG2ycdW2OjgRERERERGRB5yEEhERERERUWI4CSUiIiIiIqLEcBJKREREREREieEklIiIiIiIiBLDSSgRERERERElhpNQIiIiIiIiSgwnoURERERERJQYTkKJiIiIiIgoMfl6d8ALVX9thZG/tiw04e1Nhcjj/vfF5zFhMR33wVlQVUQjI7E17/y9j5raatu4w1kT9vXb+lWrmupMKu62Clv2mJoKh4fdRQcPm9ryqTy3hG1/dHVszZK/OGRq60fPPuasuf+OXza1Fe3tdtZI3nY52vup65w18//iBVNbwRULnDUyVDa1hb22MouwpYi+W+bF1rQ/Zxtzyosvcdao8enohv4BZ43Mn2NqK9h3wF0z+yJTW0OLupw1pWcN5ywAnLSVmTQ3Qpe+PbYkt/kNU1M/3vGss+b9t/2iqS0dHHIXGR8LRQOD7qLrrjK1FTYWnDWFY4a+A8
CrtjILjSJErn1mfSxneWyy76CtrRTzem03khBo6M3F1uSG4h8HnRKUa86a0NTS6OMvXzTpOcx54iuhRERERERElBhOQomIiIiIiCgxnIQSERERERFRYjgJJSIiIiIiosRwEkpERERERESJcU5CRWS+iPxERLaKyBYR+dQENe8WkT4R2Tz277NT0106V2Udwobop1gXrQSAa5ljOpWjQbw4vArPDP0AYI6pVY4G8eLIGjxT/iHAHFNruNaPFw6vwFMHvgEwx9QqRwN4cfDHeGbgewBzTK3hykm8sOubePqNrwDMMbX4ePXCYbknfg3AH6jqRhFpBbBBRNao6tZxdU+p6i/47yL5IBAskuvQJh1YGz2yDcAnmGP6CASLi8vQlpuJ1YPfYI4pJSJYXLgBbcFMrB7+JnNMKZEAizvehfbiRVi592+YY0oJAiwu3YS2XBdWn/xfzDGlAhFcPfsutDVeglWv/hlzTCk+Xr1wOF8JVdWDqrpx7Ot+ANsAzJ3qjpFfDdKINuk49W0E5phKDUET2nIzT33LHFOqQZrQFjDHtCvlmtFefGstS+aYUqPj6ltrjTLHlGootKKt8a11cZljSvHx6oXDtjr4GBFZAOB6AM9P8OvbROQlAAcA/KGqbjnv3nlcwNXKtPCylaH/GhqXs/W7L4rwlKOIQEoNsRuTUsnWq2rFWWJta+DSVmdN45b4RYvf2mZLs7sosi0QLIa8XfvzLYP+cgTgPMbaNx02dSs0LGhvXqja63Hvzsh8/hv6FVXMi3F7y7F0sIIln9sTu7HakR5Tp+6/+8POmmjXTlNbWnMv7K0jtkXCL/vKa86a0LA9AIh27nbWaGQ+Br3lGOWA4ZnxzxG3NxRNnRrpcF/m1Xh3iFLe3VbUZBu/JOcef6Mm23hv+RsbDX0f429cHSojeHFbbElkuO4BwP13/LKzJtwXf+6/RT0uaG8YC+Wl101N5QJx1kTWx0wecww7m3HyvptiN9Z8wJZj8aVdzprHt/zE1NZ9l9/qrNHQ+NikWHC3ZRyjTa5bbKtb7y/HhkPDWPCXr8RuToeHTd0KLdeFyHas6oj5mHZvsr/fW1sm5/kYzTwqi0gLgH8B8GlVPTnu1xsBXKaqAyJyP4DvA1g0QRvLASwHgBKazrnTdO5qWgOAKwD8mpccxTBBI+9qWgV85sjzsS6855hrmdoO04R851ho6Rj/a0oAx9Vs8J1jsYnnYz3UwgrAx6uZZnr+U0QKGJ2AfktVvzf+96p6UlUHxr5+HEBBRLomqHtIVZep6rICjK/2kDeRRngZ6wDgmK8ci2J8lZO8iTTC5pNrAI858nxMXqQRXo6eBXyej0HjlPebzjQVOeZLfLCUtEgjbO73PK7y+pi4SCO8HD4D8HxMtSgK8fLO7wJ8vJpplrvjCoCvAtimqn81Sc3ssTqIyM1j7fb67CidH1XFVqxHM1oBYML3VTLH6U9VsWXgCTTnZgDMMbVUFVv1BTQLz8c0Y47ZoKrYMvgEmnMdAHNMLVXF1vAFNEsbwBxTS1Wxdc9jaC51Acwx0yxvx30HgF8D8IqIbB772X8FcCkAqOqXAXwQwMdEpAZgGMCDqnX4QCdNqg+9OIS9aEE7ACwZy5I5psyJ2mEcGHkDLblOgDmmVh96cEj38HxMOeaYDSdqh3GgsoPjasqd0B4c1N1oUZ6PaXZiYB8O9r6MlsaLAOaYac5JqKo+DSD20+Kq+kUAX/TVKfJvhnThvfggAGCtrtiqqsvG1zDH6a+jMBv3dC0HAKzqeYg5ptQMmYX35kZv/LM2/A5zTCnmmA0dhdm4p/N3AACrjv1P5phSHcEs3B08CABYU32YOaZUR+uluHvZHwMA1qz/b8wxw4z3xCMiIiIiIiI6f5yEEhERERERUWI4CSUiIiIiIqLEcBJKREREREREibHcHXdqiEAKxdiS4LK5tqYGh5014bHjprZW7nreWXPfoneY2go63QscH3/HPFNbHc8dcNZoQ/z+fM
trtjILbWxA+LbLY2uG5tjWZir21Zw1Q7MLprae+4svO2vef+O9pra6f+lSZ01+2HZTtq7NA86a4YuMa1n9q63MFz3eZ6urunNEPW5iZ9imhqHH7UX+2rJuslZD2OO4S31k+xt11z7T9pIW9Y1fs/zc1aP/FoW+Ecx+7M3YGmfOY9o87q/wpLutYNtOU1tRueyskTd2mdqacajFWRMetz0G8Eu9jQNSrhg253HM8TlGW/tlGJo0Sv7akTs2iLaHX4wvMv6NoWG/3nfl7aa2ovKQqc5Cq4bjy6eNW5PdHgAU8pDZs2JL5KhtXNVK1VkTDRnzkdh7v45t0HjcBzl3jfExQBL4SigRERERERElhpNQIiIiIiIiSgwnoURERERERJQYTkKJiIiIiIgoMZyEEhERERERUWI4CSUiIiIiIqLEcBJKREREREREieEklIiIiIiIiBKTr9uWVZ2L44Y7d5vb8uX+JXc4a6LBPlNb0eCgs6ZtxSFTW7Vw+iwue4bBYcjzr8aWtBSMh5nhb2zI29p6/9MfcNbUDu41tXXJPxgWe49sC1VHhgWOSznDYsP1YNz3ErgXXva5pjr9O5EA0tAQW6PGsUSam91Fw8OmtnyO0WI4DrVWMzZmWCTcyt+fiFprA469Z2FsTedTBVNbg2+7xF1kfDq68acjzhpZMM/UVrB7v7utubNNbQ0unumsaX7CPfYCAGyX9+TlLoDXDMTyN07Tx0I0/YUhcOxEbEk0XLa35YvH62PaHlxdAKMaERERERERTRemSaiI7BaRV0Rks4isn+D3IiJ/KyI7RORlEbnBf1fpfD0d/QjrolUAsIQ5ptdTlR9gXfVHAHNMNZ6P2cAcs+GJk9/FM/2PAswx1Z4Of4h14UqAOaYax9ULw9m8Enqnqi5V1WUT/O4+AIvG/i0H8CUfnSP/bpR3A8BW5phuN+bvAphj6vF8zAbmmA03Nd8HMMfUuzG4E2COqcdxNft8vR33AQDf0FHPAZghIoYPotA0wxyzgTlmA3PMBuaYDcwxG5hjNjDHDLBOQhXAahHZICLLJ/j9XAD7Tvt+/9jPziAiy0VkvYisr8J9gwPyb5M+CQDXMMd021j7CcAcU8/3+VhR400VyCvfOdbK7pvakV8CYP3gKsDnuKocV5Mn2BT9FOD1MfW8Xx8jXh+nG+sk9J2qegNGX/7+hIi861w2pqoPqeoyVV1WQPwdHMm/ZfIe3BLcDQBvgDmm1k2Fu3Fr4T6AOabaVJyPRSl57SO5TUWO+ZLhzsTk1c0t78ftrQ8APsdV4biatGXBe3BL7h6A18dUm5LrY8Dr43RjmoSqavfYf48AeBTAzeNKugHMP+37eWM/o2mkJI2nvqyBOaZWSZpOfckcU4znYzYwx2woBW9N/JljivH6mA0cVy8MzkmoiDSLSOuprwG8D8D4hSEfA/DrY3eruhVAn6oe9N5bOmeh1lDTt9ZJC8AcU4k5ZgNzzAbmmA01rTLHDOD5mA3M8cJhWX3+YgCPyuii3nkA31bVlSLyuwCgql8G8DiA+wHsADAE4Lemprt0rkZQxsv67KkF1a8B8KfMMX1GUMZLtSdPfcscU4rnYzYwx2yo6DA2Df7vU98yx5QaQRkvR0+f+pY5phTH1QuHqGpdNtwmnXqL3BVbI3nLHBnQyPA3RKGprdy1i5014dbXTW1B3O92DpqbnDUAoMPDtm0arKk+vGGSW16ftfaG2Xr7vF+NrYnabZ9vkuGKsybssO2v15cXnTVXf2q7qa3KzVc5a4IR2/FV3NfrrIlabftr9St/6i1Hy/loOTcAQHfucdZElaqzZrTQtl9NRp9Iiy8puo8bANAR940qrOOXz/OxYf58nfMHn46tufoLtncsffHJf3bWfPLO+HP/lHDfAWeNFAumtvZ//Dpnzdy/fsHUVrDwUmeNDNtuSrJy3xf8jaul2Xrbpb8eW6P7bU/6B10z3UWGcwMAavvdx06utdXUVtjf76wJWlpMbQUz2p
01tW7b/lobfsfvuBq8N77I+BgsZ/gbwxN9prYSZzy+LI+ZoJGpqbXRI/5yDGbqrYV747sVGq9VhmtabtYsU1Ph0aO2bU5HQc5U5vN8bJ45X6+9P/762LnR/RgNAGTA/ZjcMl4CgOTc+8J6fAUN7s8ve338ZTy3JzsffS3RQkREREREROTESSgRERERERElhpNQIiIiIiIiSgwnoURERERERJQYTkKJiIiIiIgoMZyEEhERERERUWI4CSUiIiIiIqLEcBJKREREREREibGtpj5VHIvVBu1tpmbUsPBqNDhkamvbJ9wLQi/+lG1R9aC50VkTLZxna2u3e7F35I1xHrGVmQigufjnMqKibVHioObuf5S3PW/S2FZ21oij36eEDR6fqwkMbRn/Rq9EIIVibMmbfxz/+1MWfm6Bsya313A8A4iG3TlaScF9fMm8S0xtRbv2OWsCw+LyALyejw3dQ1j0R5tia2ojI6a2PvFz9ztrwpO7TW1ZaLViqpv7N+vdbdVqprbCHbvcRaqmtnyqzChg33+IPxbn/8g23vfc1OWsiWxDNGZ9f9BZU7vmUlNb+W17nTXRFXNNbR1e2uqsuegHtuPe5/kohQLyF8fnGB62bXDXp6911iz4y1dMbaHqfsykxuNeq+5zLX+R+xgcLXQf01o2XhOO2spMVKFh6KiJ/G1uyPZ4NdU87i+rXDnEjO0D8UXH+kxtqeU6aj2HXMfW2bRlufb53PfneX3kK6FERERERESUGE5CiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixHASSkRERERERIlxTkJFZLGIbD7t30kR+fS4mneLSN9pNZ+dui7TuRisHcezx76LZ499FwCWMMd0Giz3YN1rX8G6174CMMfU4vmYDYPaj+d0DZ7TNQBzTC2ej9kwWDuBZ4+twLPHVgDMMbUGtR/PRavxXLQaYI6Z5rwftqpuB7AUAEQkB6AbwKMTlD6lqr/gt3vkS3O+A7d3/goAYNWRL20FMA/MMXWaS1247eqPAgBWb/occ0wpno/Z0CytuBV3AwDW6grmmFI8H7OhOT8Dt3d+EACw6uhXmGNKNUsrbpX3AQDWRo8wxww727fj3gVgp6rumYrOUGLawByzgDlmA3PMBuaYDcwxG5hjNjDHDDvbSeiDAP55kt/dJiIviciPRcS9qjLVUyeYYxYwx2xgjtnAHLOBOWYDc8wG5phhzrfjniIiRQAfAPBHE/x6I4DLVHVARO4H8H0AiyZoYzmA5QBQQhMQhbHbjAYGbZ2L1F2jkampBT9w12mtamorMnQ/2H/E1tbQkLNGRNztaAgA7QAemeDXZ59j0AL0DcRuMx/a9j0q7v1aGKmYmio8M9tZo+URU1uN+/udNVKNP5bf2mZ//L4CgKDmbst7jmiCVuP37eX/zbbvdaf7CcvQkDUA5xhxNiznrezZb2vLsa8AIOw95qzxnWNDaQYqd/xc7DYb1m139gsA9n7sbc6aS7/0qqmtsN9wDuVyprb0xqvdRc+/Ymor197m3l615qyJNAQG/eV46dw8Xvn9v4/d5k29H3P2CwDe9cnnnTUdBff1BQCe2XSDs2bP/U2mti4fuNhZs+9udz4A8PFf+6Gz5rFX73TWRFENOOJxXM23QVsc+6PHdtw33tDrrJGGoqktwyMmSGgcey11TY2mprSh4KxxP8qZmutjogLeMxQAotHH7X5z3Pxa7DZD63GvlrPIyGNbWnNfr6aTsznS7wOwUVUPj/+Fqp5U1YGxrx8HUBCRrgnqHlLVZaq6rICGc+40nbue6AAADPnKsRjYLi7k19Hh3YDHHHk+1kcPDgE+cyw2T3mf6Wf1hN2AxxxnzbRNTMivnhNvAD6vj/mEJy8EADha3g3w+ph6PTgI+MxRmON0czaT0I9gkpfERWS2jL0MJyI3j7XrftqOEnco3A0AE75EwxzT4+DQdoA5pt7h0Y+5MMeUO1TbBTDH1Dvc8wrAHFPv4ODrAHNMvc
O6D2COmWZ6O66INAO4G8BHT/vZ7wKAqn4ZwAcBfExEagCGATyo6vO1avKhpjX0hgcB4MSpnzHH9KlFVfSW9wLMMdVCreHY6BO8zDHFalpFb43jatqFYQXH+nYCzDHValEVvSO8PqZdqDUcA6+PWWeahKrqIICZ43725dO+/iKAL/rtGvmWlzzubPwVrB7+5ltvemeO6ZMPCrhr3u9i5d6/YY4plpM87sj9EtaG32GOKZaXAu5seRCrB77OHFMslyvijpv+CGvXfZY5plg+KOCuuR/Fyn1fYI4plpM87pAHsDZ6hDlmGD/9TERERERERInhJJSIiIiIiIgSw0koERERERERJYaTUCIiIiIiIkqM6cZEU0bilx2WvK17PhdnHZjrXiy5U2xzd8m566SpZGtrcNBdZFzsHcO2MpN8Duhsjy2pzrStXZgbrDhrqp22dUlb7jnkrJF/tLU1sNC9YHpuxHZTtkbDwvfhDONaj3ttZSbiPt9e/60OU1NXfdmdY3DwiKktHRkx1ZkYzo/c7ItMTYXdB501QYsxx+O2MguJFLlhx2LbUWRqq9DvrvF5M0KNbG3JiHsxcXOvLNusww0XX9vbhXd9bHlszUUv7DK19fKb1zlrNIi/Fp9SeD1+oXcAuOLrs01t6b4DzpoF/baL1fdeeJ+zpviKu+/eRQoZcYyHxuP+eHf8dRYAZtfc1z0AQOg+h6zjhOm8NVz3AMB0FHp8vGfWVIIsuTq2JHfc8BgNQHTAndH2P19iauuqT21w1pjHVcP1UWtVU1sW+blzbIX7vG0SUigi59hu1Gu8IFfd+yIql21tOeZCAOzXocAwD4gM539C+EooERERERERJYaTUCIiIiIiIkoMJ6FERERERESUGE5CiYiIiIiIKDGchBIREREREVFiOAklIiIiIiKixHASSkRERERERInhJJSIiIiIiIgSw0koERERERERJUZUtT4bFjkKYM+4H3cB6KlDd3xIU98vU9VZPhpijnXFHCeXpr4zx8mlqe/McXJp6jtznFya+s4cJ5emvjPHyaWp7xPmWLdJ6EREZL2qLqt3P85FmvvuW5r3RZr77lua90Wa++5bmvdFmvvuW5r3RZr77lua90Wa++5bmvdFmvvuW5r3RZr7fgrfjktERERERESJ4SSUiIiIiIiIEjPdJqEP1bsD5yHNffctzfsizX33Lc37Is199y3N+yLNffctzfsizX33Lc37Is199y3N+yLNffctzfsizX0HMM0+E0pERERERETZNt1eCSUiIiIiIqIMmxaTUBG5V0S2i8gOEflMvftztkRkt4i8IiKbRWR9vftTL8wxG5hjNjDHbGCO2cAcs4E5ZgNznB7q/nZcEckBeB3A3QD2A3gRwEdUdWtdO3YWRGQ3gGWqmpb1erxjjtnAHLOBOWYDc8wG5pgNzDEbmOP0MR1eCb0ZwA5VfVNVKwAeBvBAnftEZ485ZgNzzAbmmA3MMRuYYzYwx2xgjtPEdJiEzgWw77Tv94/9LE0UwGoR2SAiy+vdmTphjtnAHLOBOWYDc8wG5pgNzDEbmOM0ka93BzLinaraLSIXAVgjIq+p6pP17hSdNeaYDcwxG5hjNjDHbGCO2cAcsyETOU6HV0K7Acw/7ft5Yz9LDVXtHvvvEQCPYvSl/gsNc8wG5pgNzDEbmGM2MMdsYI7ZwByniekwCX0RwCIRWSgiRQAPAniszn0yE5FmEWk99TWA9wF4tb69qgvmmA3MMRuYYzYwx2xgjtnAHLOBOU4TdX87rqrWROSTAFYByAH4mqpuqXO3zsbFAB4VEWB0f35bVVfWt0vJY47ZwByzgTlmA3PMBuaYDcwxG5jj9FH3JVqIiIiIiIjowjEd3o5LREREREREFwhOQomIiIiIiCgxnIQSERERERFRYup2Y6KiNGgJzX4aG/1wbjzjZ18ln/OzPQBarRmaMrbl8bO7/Tjeo6qzfLRVzDdpY6E9tkZHKj42NdaYbT9oW5OzRm
qRbZtDZVtdwrzm6PN8NJDA9vyX18+sT9PPvyedYz3GnAtBms9Hr4zHl+V8NI8TkXEsN2COU8B4SMDjkMMcq2KF/QAAIABJREFUs4E5ZsNkOZomoSJyL4AvYPQuUv+gqn8+7vcNAL4B4EYAvQA+rKq749osoRm3BO91bNh2AZKC+8/Qim0ilOvodG8vb5u71w4ddtYEpZKprah8/hOhHj2E17EZAFpF5DM+cmwstOO2Bb8Zu93ozb22DgbuK5VlYg8AI++8wVlTOjRkaks3bzXVmViOaY1/QNWjh/C6bgI85lhCM26Ru+L7FRieoAGAKHSWBI3uJwkAe962tjw+GeLhia+pOB8tOZrHnJERU52JZULrcfJiZjmmHcdzvXI0s+xX47XWtLmcbZywnI/WcSIaMozljv1Qt3G1Hsd9wqyPmbR2/uP9tD8f08x6rJraih9zevQgXo94PmbFWl2xZ6KfO688IpID8HcA7gOwBMBHRGTJuLLfBnBcVa8E8NcAPn9+3SXfVBXbsQlL8U4A2ALmmEqqiu26EUvl5wHmmFo8H7OBOWYDx9Vs4PmYDaoRtkcbsDR4F8AcM83y9OfNAHao6puqWgHwMIAHxtU8AODrY1+vAHCXWN/zRYnowzE0ogVN0gKMvumFOaYQc8wG5pgNzDEbmGM2MMdsGM2xlTleACyT0LkA9p32/f6xn01Yo6o1AH0AZvroIPkxgmGU0Hj6j5hjCo3meMZb1JhjCvF8zAbmmA0cV7OB52M2jGAYJWGOF4JEb0wkIssBLAcwfsCnFDkjx3xbnXtD54rnYzYwx2xgjtnAHLOBOWYDc5zeLK+EdgOYf9r388Z+NmGNiOQBtGP0g8JnUNWHVHWZqi4roOHcekznpAGNKGP49B95ybGY50mdpNEcz7gJB8/HFJqq85E5Jos5ZgPH1Wzg+ZgNDWhEWZnjhcAyCX0RwCIRWSgiRQAPAnhsXM1jAH5j7OsPAvg35f39p5U2dGAYAxjWQWD0hunMMYWYYzYwx2xgjtnAHLOBOWZDGzoxjH4M6wDAHDPN+XZcVa2JyCcBrMLoEi1fU9UtIvI5AOtV9TEAXwXwTyKyA8AxjB4w8cRwK3frEi2W23+H7mUjAEAaG91FhiVhRhtzf0ZaikVbW5blEmLOv0ACLNal2ISnAOBaAH/iJceRCnTP/vhu+Vwaw6jpuR3OGuuyPX7XqrQdh5MJIFiMpdikTwI+cxT3eRRcscDWycNHnSUn7r3G1FTrm4POGs3Zxon8tt3OmmjYthRS0OJea0wr1cn//wCurt6GTSPPAJG/HCUIEDTF92373y929h0Arv78gLsoNK7PePCIs0Qvn2dqSnaPf0L8Z0UD7uMGAIbvdS/l1LI5fntLhu/CphNPAjWP56NPlvHrPMelMzfnb83OqJzMMkFTNq4CzscBQZNxGZpB9zEtDcZXeiLLMWHLUQ2PrXKX2s7t2q4JV3EwCySHxXo9Nqnnxzk0KqGlnHLIYXG0DJtqyZ+Pq7o3mZq5Z85SZ411OTTLPCDs77e1lS+4izyO0Zbzf7Rw4h+bZlOq+jiAx8f97LOnfV0G8CFbT6heuuQSdOESrNUVr6rqnwHMMY2YYzbMKszHrMJ8rD75v5hjis1qXIhZjQuxct8XmGOKcVzNhi65BF1yCdZGjzDHFJsVzMWs4lysqXybOWaYv6c1iIiIiIiIiBw4CSUiIiIiIqLEcBJKREREREREieEklIiIiIiIiBLDSSgRERERERElhpNQIiIiIiIiSgwnoURERERERJQYTkKJiIiIiIgoMfl6bVgkgDQ2OmrE1lZzk7MmOmlqCtV5M501YZNttxW6DzprpKPd1FZQqThrNIxMbcHdlJ0IEDieyzDmaKJqq8u7M5IwPM/OZIgC6tofPcdMTUWDw86a9ldPmNoKjhtOXOPxFQ6XnTVarZna0sEhd431fPRIowiR4++86m9HbG3t6TYU2c7HqOzeZrBzn6mt0LDvEdnO7e
bnd7m3d9J48fAtyMX/3vg3BqWSuyjn2NZbjbmft5aGoqmp8Jh7DMhfNs/W1j7DsWr9G93DxNlxnCPR4KC/TY3Yzu2k1XbtsRVar+9UH2q8phlyVGtbvjn6ds+cpd42FZWNg4m1zkCrPh/gTz2+EkpERERERESJ4SSUiIiIiIiIEsNJKBERERERESWGk1AiIiIiIiJKDCehRERERERElBjnJFRE5ovIT0Rkq4hsEZFPTVDzbhHpE5HNY/8+OzXdpXNV1kGsr67Fs5V/BYBrmWM6lXUIG/QJrNNVAHNMLZ6P2VDWQayvrMWzIz8EmGNqMcds4PUxG8o6hA3RT7EuWgkwx0yzrDVSA/AHqrpRRFoBbBCRNaq6dVzdU6r6C/67SD4IAlyVuwFtQSfWVL69DcAnmGP6CASL8Ha0SQfW6grmmFI8H7NBEOCq/FiO5W8xx5RijtnA62M2CASL5LrRHKNHmGOGOV8JVdWDqrpx7Ot+ANsAzJ3qjpFfDdKItqDz1LcRmGMqNUgj2qTj1LfMMaV4PmYDc8wG5pgNvD5mA3O8cFheCX2LiCwAcD2A5yf49W0i8hKAAwD+UFW3xLWlqlDHAuYa2Baht3ywVStVU1v5o+6FyfPFgqmt0LDwvQ7YFqqOLP23L/5bhM8cQ8ei6fVYgNqwaLfW3PlMc95ylHwOuY7OuBLs+Z3Fpk7NX9vvrPnxD/7J1Nb/3969R8l1lne+/z1VXX2/6Gb5Ilu+oRhMhmAQwgYCcbCNMQSHA2FMAjmQzFJgICcZJpkhs87KJDOsdXImmWTCOBOP5wAJFweCGRNncKxLJjEG22BZvss32ci2ZFl3taS+V+3n/NEtRRateh9Ju6t6b30/a2m5u/rxu9/ev3r33m9X1X4vWvcryZpadyzHC/5z899Pkjp27A+1NXHRGcma2q7RUFt6NL8cI7Ku2GG/WgkcWbPYMccCx3Kz2PE+0pZ7sK2Oarom2C/lOB5lJqs271t0sXfr7EwX1YKXAoF9YT09saYqw8maxqL+UFvamn6upvbnUfIdj4l9Vgnur2w0fTyxWiBrKXSt4FnwvB1oq2P5uaGm6i9sjW0zwlt7XD0tWPA2MoFDZng8TrY2x8rAQKguO5i+zkFaeBJqZv2SviXpN9392JnaRknnu/shM7tO0rclrZiljdWSVktSt3pPutM4eXWvS9LFkj5KjsWVe46V4MUeclVvTEqMx8Kr+5REjoVHjuXAdU45MB7LL/RnDTOraXoC+jV3/5/H/tzdD7j7oZmv75BUM7Mls9Td7O4r3X1lzbpPses4UZlnesTvkaS95FhcmWd6RPdKOebYWSHHVsu8oYdfuFXKczyqa877jVfKPNPDk9+VOK4WWuaZHp74R4nxWGhzcp1Dji2XeaZH6ndL5FhqkbvjmqQvSHrC3f/4ODVnzdTJzFbNtLsnz47i1Li7NvkG9WlQknbMVkOO85+7a5M2qE8DEjkWlrvr8a3/S31dSyRyLCx31+NT96rPhiRyLCx31+OT96jPFkjkWFhc55SDu2tT4z6Oq6eByNtx3yrpo5IeNbOHZh77d5KWS5K73yTpg5I+aWZ1SWOSbnBvxwcBcTzD2qOX9bz6NSRJl85kSY4FM53jC+RYcPtHX9T2/Y+qv3upRI6FtT/bpe2NH6l/evJCjgW1P9up7fXnyLHguM4ph/2+S9uzLYzH00ByEuru31PiY8bufqOkG/PqFPK3wJboKvsFSdL67Jub3H3lsTXkOP8tsCW6Sh+UJK33W8mxoBb2Ldc1/+z/liStffRz5FhQC6tLdU3PRyRJa8e+So4FtbB6pq7p+2VJ0tqRL5NjQXGdUw4LK0t1decvSpLWTd5CjiUWvNUVAAAAAACnjkkoAAAAAKBlmIQCAAAAAFqGSSgAAAAAoGUid8edE9bdJXv1q5rWeK0aauvA8vQCtANPDYfaevK3022de9a+UFsDnzg3WbP7p88JtbXke9vTRROToba0NVYWMXFur5
75rcua1rz6T7bFGrOm97+SJPnBQ6Gm7nj0fydr3rTxQ6G2Fn/g+WSNdcSGknV2Jmt8MphjbFfEuEuTU01LBl7IQk1VhkeTNdc9dV2ord4n0uslNtK7VJJU3T/rnd5fwUfHQ23V9o4layqH0vthTmSNpj+uPb8r1ExjfCJd5LHnhDea90mSsrHYvo+0peBNErPhY9c/n6UmMS7mTCV9PAyppY9NVqvF2ooco3uCa/FZ+m/g9YFYW7Vquq3oMRrA7CyvY5IkC4zZdqgMDoTqsoMH57gnp4f5+SwAAAAAAJQSk1AAAAAAQMswCQUAAAAAtAyTUAAAAABAyzAJBQAAAAC0DJNQAAAAAEDLMAkFAAAAALQMk1AAAAAAQMu0bfXmrFbRxNLepjWNrtgc+cAF1WRN957m2zrss6vuTNasHnop1Na1C34pWbN/RagpLXy0L1lTGW19nJUpqXtnIqfoYu+VQN5T9VBTf7j34mTN7q0LQm0tzraki9xDbanRyK+tPJklF7UfOTs2Hhf0dSdrPn/Rl0NtveuCf5Wsse7APpWUDaT7VR2bCLU1NZhuqzYV61euzGS1zqYlO991fqippXek++/B52q2fzhZU12yONbWnr3pmugx56LlyZLq1u2xtvbFyiKsUlGlp/lzLIscSyRp4VCyJOvuirVVTS9WP3Fm+lwlSV3bXk7W7LosPc4kadnDPcka64tdAyj9VM1XlrV4g23QCP6O7Tj3RVn6uR+S5+8Y6ZMFX2vK0scT60mPs+m20nlbf+w4obFYWV7q23e0doOnOV4JBQAAAAC0TOilMzPbIumgpIakuruvPObnJulPJV0naVTSx9x9Y75dxam6a/N/U0e1S5IuNbMN5FhM3x39ljpUk8ix0O56+s/UUemUyLHQ7hr+BuOxBP7xpS8yHkvge36HqtOXtuRYYHft/Io6KhxXy+5EXgm90t1ff+wTYca7Ja2Y+bda0p/n0Tnk703nfViSNpFjsa3suUYix8J70wW/JJFj4b1p4DqJHAtv1RkfkMix8N6od0jkWHhvWnS9RI6lltfbca+X9GWfdp+kBWZ2dk5to3XIsRzIsRzIsRzIsRzIsRzIsRzIsQSik1CXtNbMHjCz1bP8fJmkF4/6fuvMY5hHzEwbXvyGJL2GHIvM9MD4eokcC80kbXj+ryRyLDSTtOHQnRI5FprJtGHXbRI5Ft6Dulsix0Izkzbs/VuJHEstOgl9m7u/QdMvf3/KzN5+Mhszs9VmtsHMNkxNjpxMEzgFq5Z/RG+58OOS9IxyyrE+So6ttqr7Wl3R814pxxwns/Fc+4i0VRf+st5y8a9KOeY45eTYaqsG3qu3DP68lOd49BbfEhJ689Jf0FvO+kUpz/Go2N22kZ+VulJvtqskciy0VYver7cs+ZBEjqUWmoS6+7aZ/+6UdJukVceUbJN03lHfnzvz2LHt3OzuK919Za0zeHtm5Ka7NnD4y7pyyrGjlxxbrbtyZKmB3HLsrMSWQkB+5mI81owcW627cuQYmN94tOBSCMhNd0f/4S/zG48KLn2D3HT/09ghxwLrrjIeTwfJSaiZ9ZnZwOGvJV0j6bFjym6X9Ms27XJJw+4eXFwNrVDPJlVvHPkrUEXkWEh1n1Ldj6yBSI4FxXgsh+nxOHn4W3IsqHo2pXpGjkXX8DrnxxJgPJ4+Iku0nCnptum7IatD0i3ufqeZfUKS3P0mSXdo+jbJmzV9q+SPz013cbIm66N6cNu3Dn/7GkmfI8fimfRxPTTxj4e/JceCmqyP6MEXGI9FN5mN6cGRvz/8LTkW1GQ2qgd3/6/D35JjQU1oXI/o3um7mJBjYU1mo3pw/52HvyXHEjN3b8uGh6pL/PLe9zYvqsQ+smq96bcu+cFDobb80ouSNY3u0PKqqt7/RLpmyeJQW43dewJFjVBb66a+/sBxbnl9wgYri/zyjnc1rfF6PY9NnZBKb2+yxoP7yyfm5+cI1vutOea42C+vXdu0prLigl
hj23clS/b83KtDTQ1tHk3WeDV2nKg9tiXd1ljss3g2MJAuCj5v1gx/Mb8cbZG/uXJV05rqpT8RasuffT6PLkmSssmpZE2lO/ZWqWw8sF+z2NiuLl6UbupA7NyxbvKWfHOsXtO8KPg7Wld6v878kTkX1tkZqmscPJis6TjrzFBb9R3pY45VYr9jrudHW+Rvtnc271cttr98ajJdVKmG2optMAvWpa8hqwsXhppq7NsX22ZArufHQI6nA+uIXfuGVGPP1XXjX2tpju14rp4Ojjce81qiBQAAAACAJCahAAAAAICWYRIKAAAAAGgZJqEAAAAAgJZhEgoAAAAAaBkmoQAAAACAlmESCgAAAABoGSahAAAAAICWyXHl2RPU1SmtOL9pSdYZW8x2ZHlvsmbgyeFQW0//enrh6HPPii1SO/DJs5M1u346XSNJZ9wdWNB6IrCYtSRtjZVFmFWSi6F7I7aoekhgYWxJsoH+dM3YeKitxsREqK7Q3JM52cHRUFPZZPp52L8t9lzt2BfYZiX2tzQP9MsbsQXaLfCciGxvTqTGyMu7Qs1kk1M5dOZwY+ljQHh/BdqKyg6NJGu8nuN+CLLuLlUvuqhpjb/wUqitA+/5Z8mascUWaqvRma478PrY8fLS308/D79zz+2htq67+p8na8bPTp8TJEnrvh6rC7CODlUXndG0Zsf7XxVq64wvPZCsmXp7OmtJ6hhJP6erw2OhtrR7f7LkjofXhZq67tJ3pIss+NrJ3lhZhFUqqvQkrjMtNoY8cO6wjuCleeDcZ709oaay4YPJmvGrfirUVnUyfR7deVnz68Yj/vBrsbq8eOwaAPnglVAAAAAAQMswCQUAAAAAtAyTUAAAAABAyzAJBQAAAAC0DJNQAAAAAEDLMAkFAAAAALRMchJqZpeY2UNH/TtgZr95TM3PmNnwUTW/O3ddxsk4NLVP399xi76/4xZJupQci2nED+o+X6f7fJ1EjoU1kg3r3vHv6N7x70jkWFgjflD3ZWt1X7ZWIsfCGh3ZpQ33fl4b7v28RI6FNdLYr3uGv6V7hr8lkWNhTezdqWf/4o/07F/8kUSOpZZcjMjdn5L0ekkys6qkbZJum6X0bnd/b77dQ176awv11jN/UZJ059bPb5J0rsixcPpsQJfraknSer+VHAuqrzKkK7rfI0laO/ZVciyoPhvQ5XaNJGl99k1yLKjevjO08or/S5J017rfIceC6qsu0FuGPiBJWrP3f5BjQXUtWqqLP/ZbkqRNf/gZciyxE3077jslPevuz89FZ9AygyLHMiDHciDHciDHciDHciDHciDHEku+EnqMGyT91XF+doWZPSzpJUm/5e6PH1tgZqslrZak7kq/7MWXm26sWqmGOjW4tz9dtGd/qK3zbl2RrJnsPzPUVrbrx3bBj1lyf2ewrT3poqmpUFuSFkn6/HF+duI59vY239rYeLRfAVmsbNFQssQOBJ/+Bw7E6iLM0jXu0dbyy9H6VOnuarqxkdedHepU32Pp3/G5j4Sa0uLvL0nWZLVYW2ePjCVrfPhgrLEzFiVLKoHtSZK25jseq4sXN93cHY/8fahbP/uxf5GssXrsudr9xLZkzcFVy0NtDTyQbivbHTheSnrhX70hWbP872LnDj2YY47qVba5+TWX12PH+6E7N6VrasFjYeD4de63E+eDGY1t25M1117/0VBb2vx0sqRrS/hyJ9ccG7t2Nd3YGV8aDnXKpyaTNR3/8FCorYiGB8+1gfPVda97Z2yb+2PjNijX82PyuV9pwy1WItcT1dh1dESjJ/g7BsrqfeHN5joek2qxa3LkI3xUNrNOSe+T9Duz/HijpPPd/ZCZXSfp25J+bDbn7jdLulmShmpnhK+0kZ/MG5I0JOmbs/z4JHJcSo5tkE1fIOSXY3UJObZB/uOR42o7ZFm+OQ5WFpNjG+Q9HgdtETm2AefHcvB6XWI8ltqJ/Onm3ZI2uvuOY3/g7gfc/dDM13dIqplZ+iUMtN
zu7CVJGiXHYtutlyVyLLxd41skciy83Qc2S+RYeLsb2yRyLDzOj+UwsvlJiRxL7UQmoR/Wcd6Ka2ZnmU2/L8DMVs20m+v7KpCPlxtbJGnvbD8jx+LYoRckciy87aNPS+RYeC/ve0wix8J7uf4jiRwLj/NjORx8dKNEjqUWejuumfVJulrSrx312Cckyd1vkvRBSZ80s7qkMUk3uMc/2IbWqHtdexrbJenIh5zIsXgaXtde7ZTIsdDq2ZT2TLwokWOh1RuT2nPwOYkcC63uU9pT5/xYdJwfyyGbnNDIc09L5FhqoUmou49IWnzMYzcd9fWNkm7Mt2vIW4d16MqeD2nt2Fcbhx8jx+KpWofeofdpvd9KjgXWUanpnees1p1bP0+OBdZR7dSVr/ttrX3wP5BjgXVYTVf236C1h/6SHAuM82M5VDq79Kp/+zk9/XufIccSa8PtvAAAAAAApysmoQAAAACAlmESCgAAAABomfA6oacLy9Kfa7bgGs7KAoXRz1FH2moHd2l6Lad5xSbSC7n7POx327gnn2O1Q8H9NZne97XtsQWhu/enn/eNWmDBbkmaSC/2rkYjXSPJAr+jB2py18jkh0aallz6Z/8y1NSFT28NbS8iGz6QrOl/fFesrX370zXBfb/srub7SpJsW6xfeTIzWXdX0xofCT5XhwaTNd4dXKC9kv679dSS/lBT1V3pG1nueW2srSVPNd9XkqSe7lBbOhQrC7Pmxyerxl4L8MBT2irBY2Fke1n0NYr0McC6AvnMZ9WKKoMDzWs6a6Gmst2B/dXbG2rLAttsLF0YaqsyNp6seen9gXOopOxgul8/+vn/Fmqr+nuhsvxMTLR4g6c3XgkFAAAAALQMk1AAAAAAQMswCQUAAAAAtAyTUAAAAABAyzAJBQAAAAC0DJNQAAAAAEDLMAkFAAAAALQMk1AAAAAAQMswCQUAAAAAtIy5e3s2bLZL0vPHPLxE0u42dCcPRer7+e5+Rh4NkWNbkePxFanv5Hh8Reo7OR5fkfpOjsdXpL6T4/EVqe/keHxF6vusObZtEjobM9vg7ivb3Y+TUeS+563I+6LIfc9bkfdFkfuetyLviyL3PW9F3hdF7nveirwvitz3vBV5XxS573kr8r4oct8P4+24AAAAAICWYRIKAAAAAGiZ+TYJvbndHTgFRe573oq8L4rc97wVeV8Uue95K/K+KHLf81bkfVHkvuetyPuiyH3PW5H3RZH7nrci74si913SPPtMKAAAAACg3ObbK6EAAAAAgBJjEgoAAAAAaJl5MQk1s2vN7Ckz22xmn213f06UmW0xs0fN7CEz29Du/rQLOZYDOZYDOZYDOZYDOZYDOZYDOc4Pbf9MqJlVJT0t6WpJWyXdL+nD7r6prR07AWa2RdJKdy/KorG5I8dyIMdyIMdyIMdyIMdyIMdyIMf5Yz68ErpK0mZ3f87dJyV9XdL1be4TThw5lgM5lgM5lgM5lgM5lgM5lgM5zhPzYRK6TNKLR32/deaxInFJa83sATNb3e7OtAk5lgM5lgM5lgM5lgM5lgM5lgM5zhMd7e5ASbzN3beZ2VJJ68zsSXf/brs7hRNGjuVAjuVAjuVAjuVAjuVAjuVQihznwyuh2ySdd9T35848Vhjuvm3mvzsl3abpl/pPN+RYDuRYDuRYDuRYDuRYDuRYDuQ4T8yHSej9klaY2YVm1inpBkm3t7lPYWbWZ2YDh7+WdI2kx9rbq7Ygx3Igx3Igx3Igx3Igx3Igx3Igx3mi7W/Hdfe6mX1a0hpJVUlfdPfH29ytE3GmpNvMTJren7e4+53t7VLrkWM5kGM5kGM5kGM5kGM5kGM5kOP80fYlWgAAAAAAp4/58HZcAAAAAMBpgkkoAAAAAKBl2vaZ0M6OXu/pXNC8aHIqvw1WgvPt6fdYJ9oK1Ejy8Yl8tidJkbdN93aHmjo4un23u58R23BznUM93nPWYPOiF6uhtibPCWxvW/Dt44HnjmdZrK0ACz
6/QttsR47W7T3W17wo+FwN/Y7Bp32k0Gq1UEs+NRkoCjWVq4Pal1+OlR7v6Wg+Hn0qv+OqdcTGtnd3posOjZ1ib/6JRY+rAdGnxEHfm+N47PJuJcbjaSCaY54fK8p1POaYY2Sseb2Ry7bKIM8ca1193tW3qGlN9WDgek+Sd6fPVx68xqyM19NF0UvMWvr5ZWOx31FZejx6T1eoqXyvc3I8rob2a47X9/NV8Bh9vPNjaBJqZtdK+lNNf4D3/3P3Pzjm512SvizpjZL2SPrn7r6lWZs9nQt0+SX/ovl2n38p0r0Q6ws+8ToDB4hAjSQ1nvlRssaqwYu4evrC0V59adOf7x7erKe23ilJA2b22VxyPGtQV9x8Q/O+/3pikjpj639MT+TO+2xgIiHJX9qRrMkOHQq1FVHp6QnVZYE/TNhrX9P057uHn9FTL+Sco/Xp8u7rmncsONnLDh5M1lhH8O9fgfFRPeesUFONF9N3YPdG8CIuctJIHJx3+8t62h+U8syxY1BvObP5eKy/tD3ZdUmh37G6cHGoqalXn5esqXzvoVBbEZXu2B9yIrzR/I8qu7OX9FT9ASnHHLvVpzfbO0+l26UQzTEbHz/lbe32l/W0HpLyzrFyVfMNBy9AI2OtsXtPqC1VAtcdWRsmtJEL2sT+moscu/oW6XVX/UbT7Q7e9Vyy65I09epzkzX1vtj5sffJ9HWO12JtTZ2Zvk6rPRL7HX0yfZ3mr7646c/n4no1z+Nq6BrGgi9ORP5APk9ZLfAHZknrJm95frbHk3vIzKqS/kzSuyVdKunDZnbsbOdXJe1z91dJ+hNJ/2+oV2gZ90xPvniHLnvVL0nS4yLHQnLP9OTzd+iyFeRYZO6up3yjXm8/LZFjYblnerK+QZfVrpTIsbDcXU/pQb1eb5PIsbDIsRy4Xj19RKbpqyRtdvfn3H1S0tclXX9MzfWS/nLm61slvdPyfD8UTtnwyDb1di1Sb9dCafodZuRYQEdy7F4kkWNhDWuvetSvXuuXyLFn7R/KAAAa90lEQVSwhn2Peo0ci47xWA7kWA5cr54+IpPQZZJePOr7rTOPzVrj7nVJw5Ji79NCS0xMHVRX5yvebkGOBTQxeYAcS2BCY+pW79EPkWMBTWhMXa/8LDU5FtD0eHzFRyrIsYDIsRy4Xj19tPTGRGa2WtJqSequDbVy08jRK3I8c6DNvcHJekWOqZsSYd56RY5VxmNRvSLHV/6BAgVCjuVwdI6dvYmbaGLeYjzOb5FXQrdJOvquEufOPDZrjZl1SBrS9AeFX8Hdb3b3le6+srODJ0MrddUGNDF54OiH8slxKHZDHuSjq3NwbnJU7E51yEeXejSu0aMfyifHCuOxlbrUowkfOfqhXHKsMR5bano8vuLuzORYQHOWY1f/HPUYs5mr61XG4/wTmYTeL2mFmV1oZp2SbpB0+zE1t0v6P2e+/qCk/+153jMdp2ywb5lGJ/ZobGKfNH3faHIsoMG+c8ixBAa1UGM6pLHpCQw5FtSgLdaoH9SYH5LIsbAYj+VAjuXA9erpIzkJnXmv9aclrZH0hKS/dvfHzew/mNn7Zsq+IGmxmW2W9BlJn52rDuPkVKyiS867Ths3f1WSXityLKSKVXXJ8uu08amvSORYWBWr6BK7TA/6dyVyLKyKVXRJx0ptnPoHiRwLq2IVXaLX60HdLZFjYZFjOXC9evqwdv3hYKi21K9Y9MGmNdn+4dy2Z8G1kkJr/wTX9mzs2xfYYH6L2VbPiK3nu2bnnz/g7itjG25usH+Zv/knf61pTeWxZ0Nt+SUXJmvsiVhb2WR6XdVc10HLMcfKQOxzfWsPfCm3HIc6zvArBo+9+dwxumJvZWns2JmssWBblUBd/bXp540kVTY+mazxqcDi31LsuRNZi0/S+sY38huPtsjzWpcwJPi8j6yH7PXgvp+n1vut+ebIOqG5Hlej5m2O83Vtz3kqzxyHOs/05PrL29NrdkpSdWHgfijBdbQj18
jRG8ZaYK3zxv79obbm63VOaDxG91dnen3MyHlPkrLR0XRRnvK8Xu2NfbRy7ciXZ80xtpIqAAAAAAA5YBIKAAAAAGgZJqEAAAAAgJZhEgoAAAAAaBkmoQAAAACAlmESCgAAAABoGSahAAAAAICWYRIKAAAAAGgZJqEAAAAAgJbpaNuWTbJqYg5csfy2V63mV5fqd5tYR/B3zNGFF+/SV751U9Oaj9zwqVBbv/LFv0nW/MUvvSfUVvXZbcmaxvCBUFvyLFlS6e+PNTU2lqyZuOKSUFtaEyuL8EZDjf3D+TWY2t7ERKiuEaiz+x6JbdM9VJebrNHa7bVDq/cpTi8WPNf6/Bxr1llTxznnNa1pbNseauvOFzYka9517htDbVVelz7HZA9tirXV25tuazx2vO9Ydna6rT17Q21pJFYWUR/o1J6fWd60ZtH6eqitiZ9s/nyQpKn+2LXcwKM7kzXeVQu1NXH2YLKma+PmUFs+OZWuefUFobb0w1hZiJmsqytREpt32EXNnw+SVB/qibX1g8dCdSGB645KT6xf2ehosqbxUytCbeme4/Ql9n8DAAAAAHDqmIQCAAAAAFqGSSgAAAAAoGWYhAIAAAAAWoZJKAAAAACgZZKTUDM7z8z+wcw2mdnjZvYbs9T8jJkNm9lDM/9+d266i5M11jioH+6+TXfv/JokvZYci2l8bL8e+uHNuv97fyyRY2GN+6ge8Lt0r6+RyLGwyLEcyLEcxrMR3T+2Rt8f/RuJHAtrfGJYG574ku555EaJHEstskRLXdK/dveNZjYg6QEzW+fux96/+253f2/+XUQeTBVdMvhWDXUu1Z0v3fiEpE+RY/FYpaKLX/0eDQwu011rPkuOBWUyrdDrNGgLtd5vJceCIsdyIMdyMJku6VypwepirR35MjkWlFlFP7H8XRrsO0frfvjvybHEkq+Euvt2d9848/VBSU9IWjbXHUO+uqt9GupcevjbTORYSF1dgxoYPBIbORZUl/Vo0BYe/pYcC4ocy4Ecy6Gr0qvB6uLD35JjQXV1Dmiw75zD35JjiUVeCT3CzC6QdJmkH8zy4yvM7GFJL0n6LXd/vPmWO5SdtbhpSWViMtYxz5Il1tcXa6urM7254OK/2rs/WWLV2KLE3kgvQNs4p/n+POIldSqnHJ/70VLd8PFPN91c1yNPh7r1hdU/n6zpeCLWVmNsPF0UWNQ3KhtJL+ob3Wb3D5+Jbja3HEOLOHfEDhfZSGCV8ErseW+19DarS88ItVWPLAofOJZM13m6JrjotTy/HK2rU9XlFzTdXOPZLbF+BXScdWao7uCq9MLePbffH9toYN9XBgZibWWBvAPHXknSWI7jEdOi4zFfueXok1OqP/9iLp1617lvTBcFz2nZQ8e+oHTyIgvaR9Vf3JpbW8oxx0aXdODC5q/ZLO7rDXXq0Dnpa8zJwdi5o3dr+ro2645dr46ck67rfqIn1JYsfbuZ0TODbeWYo5klr2Oi1+Tj56TPMWNLY/t+6P5A3oF9KkkemQ8lrvWOCIztAxcFc7xn9ofDk1Az65f0LUm/6e4HjvnxRknnu/shM7tO0rclrZiljdWSVktSd20oumnkqN6YlKSLJX00jxy7uhfMbYcxq7pPSTnm2K3YCRT5qntdyjPHjsG57TBmxXgsh9zHIzm2Rd45dgwtPPbHaIFGfULKczxa8MUotExoam1mNU1PQL/m7v/z2J+7+wF3PzTz9R2Sama2ZJa6m919pbuv7Ozg4NxqWdbQI8/+tSTtzSvHWo1B3WqZZ3ro0N9LeeZo3XPeb7xS5pke8XukHHPsrIb/uoycZJ7p4cnvSnmORwX/Uo3cZJ7pEd0rkWOhzUWO1V6uc1otyxp6YsNXpDzPj1znzDuRu+OapC9IesLd//g4NWfN1MnMVs20uyfPjuLUuLs2PX+7+rqXSNKO2WrIcf5zdz0+erf6qgskciwsd9cm36A+DUrkWFjursen7lWfDUnkWFjurk3aoD4NSO
RYWORYDu6uZx7+pnr7l0rkWGqRt+O+VdJHJT1qZg/NPPbvJC2XJHe/SdIHJX3SzOqSxiTd4B750BRaZf+hF7V9zyPq71kqSZfOZEmOBbO/sUMvTW5Wf3WhRI6FNaw9elnPq19DEjkW1v5sl7Y3fqR+WyCRY2FNj8cXGI8FR47lcGDvFu3aulG9A2dJ5FhqyUmou39PUtNPzbr7jZJuzKtTyN/CgeW6euW/lySt2/D7m9x95bE15Dj/Lew4S+9a+KuSpDX7vkCOBbXAlugq+wVJ0vrsm+RYUAurS3VNz0ckSWvHvkqOBbXAlugqfVCStN5vJceCIsdyGFp8od72c/9JkvS9v/035FhisdstAQAAAACQAyahAAAAAICWYRIKAAAAAGgZJqEAAAAAgJaJ3B13TmS1isbPar72Uu+u4Jp3WZYs8cHYOk9Zb2eyptFTC7VVrVaTNVYN/h3A07/j+FmtX3u1Mj6lnidfblrTGBsLtdX59PZkTWNiItRWZH/lKmvk19RIbH/lyaoVVfqbjxHrTI8NScpGRpI1lc7YGLLu9Dp7ExcvDbVV27U7WZNNToXakgfyttb/jW9yYU3Pf+jspjXL/zg9ziTJG+kxtPvqC0Nt/eAP/jxZc+0dq0Jthcb2xefF2gqcOyoHg+PxuVgZTiNmslrz46ZPTYaamrz6smRN55oNobYqvelrhWx0NNSWrOl9K6cFb1qa2leS5PXoMTpWFlEbcS19oPl2fd9wqK2FT/Qna8LXmC/vS9d0pK9DJWnBxFCyJjtwMNSW1+vJmr7n9ofaypNnWej6JKLzHx9O1wSvASL7K0+NfennTdSCv954Sv8/r4QCAAAAAFqGSSgAAAAAoGWYhAIAAAAAWoZJKAAAAACgZZiEAgAAAABahkkoAAAAAKBlmIQCAAAAAFqGSSgAAAAAoGU62rXhRqfp4HnNN9/zXE+sscBCyFOL+0JNTQ2mFwme6o8t/jtYTc/xrTO9OLMkeeB3PLC89XF6va7Gjl3JmojGrt2h7ZVddPHyPDUGunXgyhVNaw4sjz3vz/4vzZ8PkmQXnx9qa+Si9ALad/33m0NtvfvdH07WVLfuCLWVHUoveF3pjx1zlH7ah9VeHtG5f/CDpjVZ1shtewu+cl+o7tpvvDlZ41MTp9qdf2rroU25tZXl1hJOVKUndg2QjY7OcU9Okntux/PONRtyaUfKeX8Frk3CTbXh3BfR6DQdWtb8+qqvN/ZcHV+aroteY9Z29iZrvCt2XTh+RrpfvT/qCrWl8XTJ1KJ03+c1C1zfB+YAkuRTp9qZE2QWqwuM7ejveDy8EgoAAAAAaJnQn0jMbIukg5IakuruvvKYn5ukP5V0naRRSR9z9435dhWnatNXPqdqrUuSLjWzDeRYTN/zO1SdHrrkWGB37b1FHVaTyLHQGI/lQI7lQI7lcM8P/0jVKterZXcir4Re6e6vP/aJMOPdklbM/Fst6c/z6Bzyd/H1n5SkTeRYbG/UOyRyLLw3Df2cRI6Fx3gsB3IsB3Ish8te9ysSOZZaXm/HvV7Sl33afZIWmNnZObWN1iHHciDHciDHciDHciDHciDHciDHEohOQl3SWjN7wMxWz/LzZZJePOr7rTOPvYKZrTazDWa2oT6evrEH8mUyPfe3N0vSa/LKccoDn0BH7h7U3VKeOU4cmpuO4rhMpg3D35HyzFH53dwHcbmPR3JsC3Ish7xz5Hq1PR569C8kxmOpRSehb3P3N2j65e9PmdnbT2Zj7n6zu69095Ud3cE7RyI3r3r/p/UTH/qMJD2jnHKsWXeufUTaSl2pN9tVUp45dvXn2kekrRp6n96y8ANSnjkqePdC5GZOxiM5thw5lsNc5Mj1auu98adWa9UbPiUxHkstNAl1920z/90p6TZJq44p2SbpvKO+P3fmMcwjtf4jy13URY6F1W1HbqVOjgXWXT1yYUOOBcZ4LAdyLAdyLIeursHDX5JjiSUnoWbWZ2YDh7+WdI2kx4
4pu13SL9u0yyUNu/v23HuLk9aYmlBj8shbZysix0JqeF31f1pUihwLqu5TqmdH1sQjx4JiPJYDOZYDOZZDozGpev3IW2fJscQiS7ScKem26bshq0PSLe5+p5l9QpLc/SZJd2j6NsmbNX2r5I/PTXdxsuqjh7Tlzi8d/vY1kj5HjsUzoXE9onunP6VNjoU1mY3pwQNrD39LjgXFeCwHciwHciyHyclDevSJWw5/S44lZu7elg0PdSzxK/qvb1rTOJTfh8GtFloSVTOT7eaq1VBb2UhrP8xeHRxMF0laM/zFB45zy+sTNlhZ7JfXrm1a41OTTX9+mNU6kzXRtgot8hyUtD77Zm45DnWc4VcMNh+P6op9nqKxY2eyxoJtVQJ19ddeGGqr+uBTyZpscipZM13YSNdUYseJ9Y1v5DcebZG/uXpN86JI36OCz1XrDIztiWLfNGK935pvjvbOPJrCCco9x8pVzYuC12DVM5cmayLHXkmxY1Oex4moyPEkuL/yzLFvxdn+ms83n+Ms/P3Y/TF+9H+k778wtSR2Hlr+N+lP1E31xW79svt16X3/qq/sDrVlh8aSNVs/sDzU1mN/8pn5eVwNnvtC2jQPy8UpXq/mtUQLAAAAAABJTEIBAAAAAC3DJBQAAAAA0DJMQgEAAAAALcMkFAAAAADQMkxCAQAAAAAtwyQUAAAAANAyTEIBAAAAAC3T0bYtV6qywYHmJZOTuW0usli6JKkW2CUdwd02MpKuiS54G1jM1oYGY20Nx8oirKtTlYvOb1qTPfdCqK3K8mXJmmzL1lBbXg8s9pznAsE55ljp6oq1lV4POsyzTNnYeNMam4wtoB3a3lQ9VJdl6f1VezG2gHY90n/PQm2F5NlWkHV1qXrBBU1rGs88l9v2OpadE6rbf8W5yZr+W38Y22hkDA00P7eckKng8z7H8Whdnaqef1HTmuyFbbG2qtVkTTY+EWurkj7OVfr7Qm019qdPRGteeijU1ruWXZasiewHSVJ+h7lpOZ1nssD+ijfWyK+tPOV5Ts5RZWdVff91QdOa6jObQ21ddGv6OqfRHbvG7Hz25UBRLdTW4DPp60d/4aVQW1k9fX5ftqYn1NZjoaogkyx1/W6x1+aqZ56RrPHe7lBbeZ6TQ3OFWmw+5FPpOVh1RfPz1BFPzf4wr4QCAAAAAFqGSSgAAAAAoGWYhAIAAAAAWoZJKAAAAACgZZiEAgAAAABaJjkJNbNLzOyho/4dMLPfPKbmZ8xs+Kia3527LuNkHJraq+9v/6q+v/2rknQpORbTSDase8e/o3vHvyORY2GN+EHdl63VfdlaiRwLi/FYDiN+QPdO3al7p+6UyLGwRvyg7vN1us/XSeRYWCMTe3TPs1/QPc9+QSLHUkveB9rdn5L0ekkys6qkbZJum6X0bnd/b77dQ176a4v01rM/Ikm684X/sknSuSLHwumrDOmK7vdIktaOfZUcC6rPBnS5XSNJWp99kxwLivFYDn02qCtq10qS1k19nRwLqs8GdLmuliSt91vJsaD6uhbrLRf/qiRpzab/hxxL7ETfjvtOSc+6+/Nz0Rm0zKDIsQzIsRzIsRzIsRzIsRzIsRzIscROdBJ6g6S/Os7PrjCzh83s78zstafYL8ytRSLHMiDHciDHciDHciDHciDHciDHEku+HfcwM+uU9D5JvzPLjzdKOt/dD5nZdZK+LWnFLG2slrRakror/fKxsabb9EYW7V5avR6rc0+WWKNxip2ZG6n9KUmZNyRpSNI3Z/nxSeWol3c379dUcN/v2pMs8fpUrK1AjrnKcXvZZPp3zD1H9aZzyvN577Gx7YFNZsMH8ttmns+bQFvZdJ/yy7E2KO+sNd+oBf/2GNlftdgpZLI/x3vgmaVLqsHtZemMvJJuay7GY+OZ55LbjcjzSBh5SjT2D+e2vXctuyxWGBhrHrgGyH08Wp8q3d3Ntzk+nuyXJO3/hfS+GPrqfaG2qoODyZrGgeBxtVJN12Sxc0
dqX0kncH70/HLs6lqg6njidwieH20y/TysVNLHOEnyLD0grR7s10SgLnp+DBxXFbgunJPrnOhcIKG+dVu6KHCuktTy61WfmsytrVM9T53IlcG7JW109x3H/sDdD7j7oZmv75BUM7Mls9Td7O4r3X1lZyV9sEH+dk++KEmjueVo5NgOu/WylGOONXXNeZ/x4/LOsbOjb877jB+3u7FNYjwWXu7jkRzbIvccOzmutsPukeckjquldiKT0A/rOC+Jm9lZZtNTfjNbNdNu+mUttNz2ic2StHe2n5FjceyY/ngEORbcDr0gkWPhvVzfIpFj4TEey4HzYzlsP/ikRI6lFnovlZn1Sbpa0q8d9dgnJMndb5L0QUmfNLO6pDFJN7i3+v2QSKn7lPZMbpWk/YcfI8fiaXhde6f/MEiOBdbwuvZqp0SOhVb3Ke1pvCSRY6ExHsuB82M51LNJ7RndIpFjqYUmoe4+ImnxMY/ddNTXN0q6Md+uIW8dVtPPLvmY1uz670fe/E+OxVO1Dr2j+n6tb3yDHAusah16h96n9X4rORZYh9V0Zd8NWjvyZXIsMMZjOXB+LIeOSqd+9uJf15pn/pAcSyzHu0UAAAAAANAck1AAAAAAQMswCQUAAAAAtAyTUAAAAABAy8RWGp8D579mv26+4/amNT//H3871lhgAe29bwgUSVp8/r5kzdvP3hxq68nrlyVrfKg/1JbtTS8A/p0H7gy1VT07VBay4rUHdceaf2hac/WHPhZq69/85VeSNf/pw78Yaqv6zNZkTWM4uBh3YIX2Sn8sRx8bS9ZMXBVcoP3vvhGrC7BKRZWe5mu+WmdnqK3GvvQYso5arF+19CHKzos9oe2ZLckar6cXQp8uDNyEL7xQdawsJHNZajH3wPM5LLDguCR1juT4Swb2vTdy/B0DC8JjbkSPE3kuvp4nd1eWGo9Bix5MH1cbyYpp2dj4qXXmFY1FtxpoKrKvctxe1E9cuFvrb/li05q3/8vVobb6fyN9bfK+Mx8NtfU/bvq5ZM3kYKgpLbvyxXTR710SaqtjX/o654nPxK6Z9PFYWYR1d6l6waua1nhP7Dqn9qezrhzzCt961XdCbV3/k1eli6qx1wyz/em5gq+8NNSWfpB+Hq7Z9mCoqePNO3glFAAAAADQMkxCAQAAAAAtwyQUAAAAANAyTEIBAAAAAC3DJBQAAAAA0DJMQgEAAAAALcMkFAAAAADQMkxCAQAAAAAtwyQUAAAAANAy5u7t2bDZLknPH/PwEkm729CdPBSp7+e7+xl5NESObUWOx1ekvpPj8RWp7+R4fEXqOzkeX5H6To7HV6S+k+PxFanvs+bYtknobMxsg7uvbHc/TkaR+563Iu+LIvc9b0XeF0Xue96KvC+K3Pe8FXlfFLnveSvyvihy3/NW5H1R5L7nrcj7osh9P4y34wIAAAAAWoZJKAAAAACgZebbJPTmdnfgFBS573kr8r4oct/zVuR9UeS+563I+6LIfc9bkfdFkfuetyLviyL3PW9F3hdF7nveirwvitx3SfPsM6EAAAAAgHKbb6+EAgAAAABKbF5MQs3sWjN7ysw2m9ln292fE2VmW8zsUTN7yMw2tLs/7UKO5UCO5UCO5UCO5UCO5UCO5UCO80Pb345rZlVJT0u6WtJWSfdL+rC7b2prx06AmW2RtNLdi7JeT+7IsRzIsRzIsRzIsRzIsRzIsRzIcf6YD6+ErpK02d2fc/dJSV+XdH2b+4QTR47lQI7lQI7lQI7lQI7lQI7lQI7zxHyYhC6T9OJR32+deaxIXNJaM3vAzFa3uzNtQo7lQI7lQI7lQI7lQI7lQI7lQI7zREe7O1ASb3P3bWa2VNI6M3vS3b/b7k7hhJFjOZBjOZBjOZBjOZBjOZBjOZQix/nwSug2Secd9f25M48Vhrtvm/nvTkm3afql/tMNOZYDOZYDOZYDOZYDOZYDOZYDOc4T82ESer+kFWZ2oZl1SrpB0u1t7lOYmfWZ2c
DhryVdI+mx9vaqLcixHMixHMixHMixHMixHMixHMhxnmj723HdvW5mn5a0RlJV0hfd/fE2d+tEnCnpNjOTpvfnLe5+Z3u71HrkWA7kWA7kWA7kWA7kWA7kWA7kOH+0fYkWAAAAAMDpYz68HRcAAAAAcJpgEgoAAAAAaBkmoQAAAACAlmESCgAAAABoGSahAAAAAICWYRIKAAAAAGgZJqEAAAAAgJZhEgoAAAAAaJn/H7NcxnmkN9NWAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 1152x576 with 32 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yhg6POpQkCq1"
      },
      "source": [
        "Averaged over many samples, these soft pairs approximate a smoothed (and therefore noisier) version of the implicit coupling. Because the relaxation is differentiable, we can use stochastic gradient descent to optimize the coupling so that it scores better under our objective of interest."
      ]
    },
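    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough, self-contained illustration of this effect, here is a minimal NumPy sketch. It does not call this library's gadgets: it assumes the plain Gumbel-max coupling, in which both logit vectors share one Gumbel noise vector per sample, and it uses a tempered softmax as the relaxation. The helpers `softmax` and `gumbel_max_coupling` are hypothetical names defined only for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def softmax(x, axis=-1):\n",
        "    # Numerically stable softmax along the given axis.\n",
        "    z = x - x.max(axis=axis, keepdims=True)\n",
        "    e = np.exp(z)\n",
        "    return e / e.sum(axis=axis, keepdims=True)\n",
        "\n",
        "def gumbel_max_coupling(l1, l2, num_samples=20000, temperature=0.05, seed=0):\n",
        "    # Hypothetical helper: estimate the hard and relaxed joint distributions\n",
        "    # of a Gumbel-max coupling in which both logit vectors share the same noise.\n",
        "    rng = np.random.default_rng(seed)\n",
        "    n = l1.shape[0]\n",
        "    g = rng.gumbel(size=(num_samples, n))  # one shared noise vector per sample\n",
        "    # Hard counterfactual pairs: argmax of the perturbed logits.\n",
        "    x = np.argmax(l1 + g, axis=-1)\n",
        "    y = np.argmax(l2 + g, axis=-1)\n",
        "    hard = np.zeros((n, n))\n",
        "    np.add.at(hard, (x, y), 1.0)\n",
        "    hard /= num_samples\n",
        "    # Soft pairs: tempered softmax of the same perturbed logits,\n",
        "    # then the average of the per-sample outer products.\n",
        "    s1 = softmax((l1 + g) / temperature)\n",
        "    s2 = softmax((l2 + g) / temperature)\n",
        "    soft = np.einsum('bi,bj->ij', s1, s2) / num_samples\n",
        "    return hard, soft\n",
        "\n",
        "l1 = np.log(np.array([0.5, 0.3, 0.2]))\n",
        "l2 = np.log(np.array([0.2, 0.3, 0.5]))\n",
        "hard, soft = gumbel_max_coupling(l1, l2)\n",
        "print(np.round(hard, 3))\n",
        "print(np.round(soft, 3))\n",
        "```\n",
        "\n",
        "With these example marginals the hard coupling never produces the pair (2, 0), since that would require the shared noise to favor category 2 over category 0 under the first logits and category 0 over category 2 under the second. The soft estimate is strictly positive everywhere, including on such impossible pairs, and its extra spread shrinks as the temperature decreases."
      ]
    },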
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 369
        },
        "id": "4qkQugplixMf",
        "outputId": "7f18df71-8c30-477d-e120-abaa8b7d8291"
      },
      "source": [
        "_, axs = plt.subplots(ncols=2, figsize=(12,6))\n",
        "# True implicit coupling defined by gadget 2\n",
        "axs[0].imshow(g2_init_pq, vmin=0)\n",
        "# Relaxed approximation with gradients\n",
        "axs[1].imshow(jnp.mean(soft_pairs, axis=0), vmin=0)"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b85819b90>"
            ]
          },
          "metadata": {},
          "execution_count": 28
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAAFPCAYAAABJfdYtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAS7UlEQVR4nO3dXYzl913f8c93ZvY5ZO0QEPih9gIu1IrUJmzTJG5TKYkgQEhuKghK0hZR+aaQhyKh0IsicVkhFCohJCsPF8VNpJpcoCgQKkioKoFh44SCvQ4KdvBDHGwjx07WD7uz8+vF7Kpumt09szvn/Nbfeb2kSJ7Z4/P9/bPr7773v2fm1BgjAADQydrsAwAAwG4TuQAAtCNyAQBoR+QCANCOyAUAoJ2NZTzpoWsOjJdfd2QZT31Rz51c+UigmedzKqfHCzX7HKu0cejI2Hf0FSufu/+Zsyufed7YmHOPp144M2Xuuelzxs78Lk578TtIzbrmmvPr67mz38jpree+7fClRO7LrzuSn77zR5fx1Bd18oc3Vz4T6OXu8Yezj7By+46+It//7v+w8rnX/+FTK5953plrD02Ze+BvHp8yN0myPukvbzcn/mHmzKQ/VGxNjOvNSS20f9+UsX/y5H+/4I95uQIAAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaWShyq+qtVfWlqvpyVX1w2YcC4PLZ2QALRG5VrSf5zSQ/luTWJD9TVbcu+2AA7JydDbBtkTu5r03y5THGA2OM00k+keQdyz0WAJfJzgbIYpF7fZKHX/TxI+c+B8DVx84GyC5+4VlV3V5VJ6rqxHNPvbBbTwvAErx4Z5999tTs4wDsukUi99EkN77o4xvOfe7/Mca4Y4xxfIxx/NC1B3brfADszI539vrhIys7HMCqLBK5f57klqo6VlX7k7wzye8u91gAXCY7GyDJxqUeMMbYrKqfT/KZJOtJPjrGuHfpJwNgx+xsgG2XjNwkGWN8Osmnl3wWAHaBnQ3gHc8AAGhI5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaGehdzzbqefuT+7/Z8t45ov7ub9+cPVDz/nIPzw2bTbAldj3za18z5+eWvnc3/u9j6985nk//pafmjN4Y33O3CTZPDtn7tlJc5PUkcNzBp/ZnDM3STbn/BobW1tT5l6MO7kAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdjaW8aRVa6kDB5bx1Bf10VtvWfnM8z70lf85Ze77b37DlLlAH2cPrOWZY4dWPvetP/mulc8870vvPTJl7q3/+fEpc5MkG+tTxo7nnpsyN0nq4OpbJEmyuTlnbpJsLCXtLqmmTE2yduHJ7uQCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7l4zcqrqxqj5bVfdV1b1V9b5VHAyAnbOzAbZtLPCYzSS/OMa4p6q+I8nnq+p/jDHuW/LZANg5OxsgC9zJHWM8Nsa459w/fyPJySTXL/tgAOycnQ2wbUevya2qm5O8OsndyzgMALvHzgb2soUjt6peluR3krx/jPHMt/nx26vqRFWdOD2e380zArBDO9nZmy+cWv0BAZZsocitqn3ZXpZ3jjE++e0eM8a4Y4xxfIxxfH8d3M0zArADO93ZGweOrPaAACuwyHdXqCQfSXJyjPHryz8SAJfLzgbYtsid3NuSvCfJm6
rqi+f+9+NLPhcAl8fOBsgC30JsjPG/ktQKzgLAFbKzAbZ5xzMAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2rnkO55dlqrUvuU89dXqAz/wxilzP/LQ56bMTZKf+wf/fNpsYPesP7+Va/761Mrn1tmx8pnn/aP/8vUpc5+87XunzE2S77z78Slz69ChKXOTJFtb82ZPMp59ds7ga4/OmVsXfoNHd3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7Wws5VnX1lJHjizlqS/q+edXP3Oyf/f9b5o2+1OP/smUuW+7/oenzIWuxnrl9NH9K5+7/+svrHzm/7Wc3/4u5RV/8dSUuUnyN+/57ilzb/r0s1PmJsm+R/5+ytyt77pmytwkWfv7Z+YMPrM5Z+4YF/whd3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKCdhSO3qtar6gtV9allHgiAK2dnA3vdTu7kvi/JyWUdBIBdZWcDe9pCkVtVNyT5iSQfXu5xALhSdjbA4ndyP5Tkl5JsXegBVXV7VZ2oqhOnt57blcMBcFl2tLPPnDm1upMBrMglI7eq3pbk8THG5y/2uDHGHWOM42OM4/vXDu3aAQFY3OXs7H37jqzodACrs8id3NuSvL2qvpLkE0neVFW/vdRTAXC57GyALBC5Y4xfHmPcMMa4Ock7k/zRGOPdSz8ZADtmZwNs831yAQBoZ2MnDx5jfC7J55ZyEgB2lZ0N7GXu5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQzo7e8Wxha2sZRw4t5akvptYmNvva+qS5z8+Zm+Qnb379lLmf+eqfTZmbJD963T+ZNhuWZi3Z2r/6/Tk25u3staefnTN4/745c5Nc/8enp8x9+pbDU+YmySsf/NqUuXXm7JS5STI2N6fMndpgF3D1nQgAAK6QyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2tlYxpOO9bVsHT28jKe+qLX1ec1e0ybvPW+96bXTZv/Xhz87Ze57brxtylz2iJHU2bHysd+4afW/T5x34Oj+KXP3P316ytwkOfC1U1PmHvrS302ZmyRP/sj3TZn7yj9+ZMrcJBlnz06bfbVxJxcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANpZKHKr6pqququq7q+qk1X1+mUfDIDLY2cDJBsLPu43kvz+GONfVdX+JIeXeCYAroydDex5l4zcqjqa5I1J/m2SjDFOJzm93GMBcDnsbIBti7xc4ViSJ5J8rKq+UFUfrqoj3/qgqrq9qk5U1Ykzm6d2/aAALGTnO/u0nQ30s0jkbiR5TZLfGmO8OsmpJB/81geNMe4YYxwfYxzft/H/7VMAVmPnO3u/nQ30s0jkPpLkkTHG3ec+vivbCxSAq4+dDZAFIneM8bUkD1fVD5771JuT3LfUUwFwWexsgG2LfneFX0hy57mv0n0gyc8u70gAXCE7G9jzForcMcYXkxxf8lkA2AV2NoB3PAMAoCGRCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAG
hH5AIA0I7IBQCgnUXf1ndn1pKzB5fz1Bcz6uDKZ563PmlurdWkyRNtbU0b/a+P/cspc//TA382ZW6S/Or3vWbabFZjrFU2D69+i41ZizPJ2uaYN3ySb95ydMrcQ4/P+735O//08Slzj901Z26SPPjT3ztl7tZ3HJoyN09ceJG4kwsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoZ2MZTzrWK6eP7lvGU1/Uvm/WymeeV1tjyty1SXOTpGrS/99bZ+fMnehXf+CfTpv9rvv/dsrcO3/ohilz96RKxoRbHlvr83b2qe9Z/e9RSXLtvc9NmZskh/7u+SlzN556dsrcJKkXTk+Z+8C/uWnK3CQ5+90Hp8xdf2bOr6+MC3eQO7kAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANDOQpFbVR+oqnur6q+q6uNVNeeNkQG4JDsbYIHIrarrk7w3yfExxquSrCd557IPBsDO2dkA2xZ9ucJGkkNVtZHkcJKvLu9IAFwhOxvY8y4ZuWOMR5P8WpKHkjyW5Okxxh986+Oq6vaqOlFVJ86cPrX7JwXgki5rZ7/wzVUfE2DpFnm5wrVJ3pHkWJLrkhypqnd/6+PGGHeMMY6PMY7v239k908KwCVd1s4+8LJVHxNg6RZ5ucJbkjw4xnhijHEmySeTvGG5xwLgMtnZAFksch9K8rqqOlxVleTNSU4u91gAXCY7GyCLvSb37iR3JbknyV+e+3fuWPK5ALgMdjbAto1FHjTG+JUkv7LkswCwC+xsAO94BgBAQyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7C73j2U5trVdeODqhn2spl7Pg7DljJ15x1saYMrcOHpwyN0myNeea186enTI3Sf7bq26eMvdVn9+aMvd/v2vK2KnOHki+fsv6yucefHLOf09JcvbgnHs8j/2La6bMTZKXP7w5Ze6BffPup62fOjBl7ulr5sxNkv1PPT9l7kNvf+WUuac/duEScicXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0E6NMXb/SaueSPK3l/mvvzLJk7t4nJcC19zfXrve5KV7zTeNMb5r9iFWyc7eMde8N7jml4YL7uylRO6VqKoTY4zjs8+xSq65v712vcnevOa9aC/+PLvmvcE1v/R5uQIAAO2IXAAA2rkaI/eO2QeYwDX3t9euN9mb17wX7cWfZ9e8N7jml7ir7jW5AABwpa7GO7kAAHBFRC4AAO1cNZFbVW+tqi9V1Zer6oOzz7NsVXVjVX22qu6rqnur6n2zz7QqVbVeVV+oqk/NPssqVNU1VXVXVd1fVSer6vWzz7RsVfWBc7+u/6qqPl5VB2efid23l/a2nW1nd9Z1Z18VkVtV60l+M8mPJbk1yc9U1a1zT7V0m0l+cYxxa5LXJfn3e+Caz3tfkpOzD7FCv5Hk98cYP5TkH6f5tVfV9Unem+T4GONVSdaTvHPuqdhte3Bv29l7h53dZGdfFZGb5LVJvjzGeGCMcTrJJ5K8Y/KZlmqM8dgY455z//yNbP9HdP3cUy1fVd2Q5CeSfHj2WVahqo4meWOSjyTJGOP0GOPrc0+1EhtJDlXVRpLDSb46+T
zsvj21t+1sO7u5ljv7aonc65M8/KKPH8keWB7nVdXNSV6d5O65J1mJDyX5pSRbsw+yIseSPJHkY+f+uu/DVXVk9qGWaYzxaJJfS/JQkseSPD3G+IO5p2IJ9uzetrNbs7Mb7eyrJXL3rKp6WZLfSfL+McYzs8+zTFX1tiSPjzE+P/ssK7SR5DVJfmuM8eokp5J0f+3itdm+o3csyXVJjlTVu+eeCnaHnd2end1oZ18tkftokhtf9PEN5z7XWlXty/ayvHOM8cnZ51mB25K8vaq+ku2/2nxTVf323CMt3SNJHhljnL/jc1e2F2hnb0ny4BjjiTHGmSSfTPKGyWdi9+25vW1n29lNtd3ZV0vk/nmSW6rqWFXtz/YLnn938pmWqqoq26/5OTnG+PXZ51mFMcYvjzFuGGPcnO2f4z8aY7T40+KFjDG+luThqvrBc596c5L7Jh5pFR5K8rqqOnzu1/mb0/wLN/aoPbW37Ww7u7G2O3tj9gGSZIyxWVU/n+Qz2f6qvo+OMe6dfKxluy3Je5L8ZVV98dzn/uMY49MTz8Ry/EKSO8+FwANJfnbyeZZqjHF3Vd2V5J5sf0X6F9LsrSLZk3vbzt477OwmO9vb+gIA0M7V8nIFAADYNSIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0M7/AUTd3LvWqxdGAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 864x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4yefr0hmsed"
      },
      "source": [
        "### Training the gadgets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "brnHS8TNmub2"
      },
      "source": [
        "So far, we have described how to use the gadgets once we know a particular value for $\\theta$. It remains to see how we can *learn* $\\theta$ to optimize an objective of interest.\n",
        "\n",
        "To this end, we provide a training helper class, `CouplingExperimentConfig`, that can train either of our gadgets to minimize a chosen objective over a distribution of interest. To use it, you must provide:\n",
        "- a model definition (either gadget 1 or gadget 2)\n",
        "- a function that generates random pairs of logits according to the distribution of interest (this is $\\mathcal{D}$ from Section 3)\n",
        "- a function that takes a pair of logits and returns a matrix of scores for every pair of counterfactual samples (this is $g_{l^{(1)}, l^{(2)}}$ from Section 3)\n",
        "- a flag specifying whether to pass the `transpose` argument during training (`True` for gadget 1, `False` otherwise)\n",
        "- hyperparameters that control the training process, such as the batch size, number of samples per iteration, and optimizer.\n",
        "\n",
        "To show how this works, here is an example of training each of our gadgets based on two distance functions:\n",
        "- $g(x, y) = 0\\text{ if }x=y\\text{ else }1$, which encourages our coupling to be closer to a maximal coupling.\n",
        "- $g(x, y) = (x-y)^2$, which encourages our coupling to minimize the variance of the difference between the sampled indices.\n",
        "\n",
        "In both cases, we take the pair $p$, $q$ that we have used for the rest of this notebook and perturb its logits with random Gaussian noise, so that the training data represents a distribution over intervention pairs rather than a single pair."
      ]
    },
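    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The objective being minimized is the expected score $\\mathbb{E}_{(x, y) \\sim P}[g(x, y)] = \\sum_{i, j} P_{ij} g(i, j)$, where $P$ is the joint distribution induced by the coupling. The cell below is a rough illustration of this objective, assuming softmax marginals. It does not use the gadgets or `experiment_util`. It evaluates both score matrices on two hand-built couplings of the noiseless version of the training pair: the independent coupling $p q^\\top$, and a maximal coupling that places $\\min(p_i, q_i)$ on the diagonal. For the 0-1 score, the maximal coupling reaches $1 - \\sum_i \\min(p_i, q_i)$, the lowest value any coupling of $p$ and $q$ can attain. The helpers `softmax`, `maximal_coupling` and `expected_coupling_loss` are illustrative names defined here, not part of the library."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax(logits):\n",
        "  z = np.exp(logits - np.max(logits))\n",
        "  return z / z.sum()\n",
        "\n",
        "def expected_coupling_loss(joint, loss_matrix):\n",
        "  # E_{(x, y) ~ joint}[g(x, y)] = sum_ij joint[i, j] * g[i, j]\n",
        "  return float(np.sum(joint * loss_matrix))\n",
        "\n",
        "def maximal_coupling(p, q):\n",
        "  # Put min(p_i, q_i) on the diagonal, then spread the leftover mass as an\n",
        "  # independent coupling of the residuals (which never overlap on the diagonal).\n",
        "  overlap = np.minimum(p, q)\n",
        "  residual_mass = 1.0 - overlap.sum()\n",
        "  joint = np.diag(overlap)\n",
        "  if residual_mass > 0:\n",
        "    joint = joint + np.outer(p - overlap, q - overlap) / residual_mass\n",
        "  return joint\n",
        "\n",
        "demo_dim = 10\n",
        "demo_base = np.arange(demo_dim) - (demo_dim - 1.0) / 2\n",
        "p_demo = softmax(0.1 * demo_base)   # base_scale=.1, no noise\n",
        "q_demo = softmax(-0.1 * demo_base)\n",
        "\n",
        "demo_seq = np.arange(demo_dim, dtype=np.float64)\n",
        "zero_one_g = 1.0 - np.eye(demo_dim)                          # g(x, y) = 0 if x == y else 1\n",
        "squared_g = np.square(demo_seq[None, :] - demo_seq[:, None])  # g(x, y) = (x - y)^2\n",
        "\n",
        "for name, joint in [('independent', np.outer(p_demo, q_demo)), ('maximal', maximal_coupling(p_demo, q_demo))]:\n",
        "  print(name, expected_coupling_loss(joint, zero_one_g), expected_coupling_loss(joint, squared_g))"
      ],
      "execution_count": null,
      "outputs": []
    },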
    {
      "cell_type": "code",
      "metadata": {
        "id": "BIggP9IwjXzb"
      },
      "source": [
        "def logit_pair_distribution_fn(rng, dim, base_scale=.1, noise_scale=.1):\n",
        "  # Sample one (p, q) logit pair: mirrored linear base logits plus Gaussian noise.\n",
        "  p_rng, q_rng = jax.random.split(rng, 2)\n",
        "  p_base = jnp.arange(dim) - (dim - 1.0) / 2\n",
        "  q_base = -p_base\n",
        "  p_logits = base_scale * p_base + noise_scale * jax.random.normal(p_rng, (dim,))\n",
        "  q_logits = base_scale * q_base + noise_scale * jax.random.normal(q_rng, (dim,))\n",
        "  return p_logits, q_logits\n",
        "\n",
        "def maximal_coupling_loss_matrix_fn(logits1, logits2):\n",
        "  # g(x, y) = 0 if x == y else 1; the logits themselves are unused.\n",
        "  return 1.0 - jnp.eye(logits1.shape[0])\n",
        "\n",
        "def squared_loss_matrix_fn(logits1, logits2):\n",
        "  # g(x, y) = (x - y)^2 over the outcome indices; the logits themselves are unused.\n",
        "  seq = jnp.arange(logits1.shape[0]).astype(jnp.float32)\n",
        "  return jnp.square(seq[None, :] - seq[:, None])\n",
        "\n",
        "experiments = []\n",
        "S_dim = 10\n",
        "Z_dim = 100\n",
        "for task_fn in [maximal_coupling_loss_matrix_fn, squared_loss_matrix_fn]:\n",
        "  for gadget in [1, 2]:\n",
        "    ex = experiment_util.CouplingExperimentConfig(\n",
        "      name=f\"Gadget {gadget} example training: {task_fn.__name__}\",\n",
        "      model=(\n",
        "          gadget_1.GadgetOneMLPPredictor(\n",
        "              S_dim=S_dim,\n",
        "              hidden_features=[1024, 1024],\n",
        "              relaxation_temperature=1.0)\n",
        "          if gadget == 1 else\n",
        "          gadget_2.GadgetTwoMLPPredictor(\n",
        "              S_dim=S_dim,\n",
        "              Z_dim=Z_dim,\n",
        "              hidden_features=[1024, 1024],\n",
        "              relaxation_temperature=1.0,\n",
        "              learn_prior=False)\n",
        "      ),\n",
        "      logit_pair_distribution_fn=functools.partial(\n",
        "          logit_pair_distribution_fn,\n",
        "          dim=S_dim,\n",
        "          base_scale=.1,\n",
        "          noise_scale=0.4),\n",
        "      coupling_loss_matrix_fn=task_fn,\n",
        "      inner_num_samples=16,\n",
        "      batch_size=64,\n",
        "      use_transpose=(gadget == 1),\n",
        "      tx=optax.adam(1e-5),\n",
        "      num_steps=2001,\n",
        "      print_every=1000,\n",
        "    )\n",
        "    experiments.append(ex)"
      ],
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P7daAX3BpSnq",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5d15d5ab-6350-48b4-cfa8-678148bf1587"
      },
      "source": [
        "results = []\n",
        "for ex in experiments:\n",
        "  print(\"=\" * 80)\n",
        "  print(ex.name)\n",
        "  print(\"=\" * 80)\n",
        "  results.append(ex.train(jax.random.PRNGKey(42)))\n",
        "  print()"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "================================================================================\n",
            "Gadget 1 example training: maximal_coupling_loss_matrix_fn\n",
            "================================================================================\n",
            "0 [0.2224611682501865/s]: {'loss': 0.9031217098236084}\n",
            "1 [33.60605089417345/s]: {'loss': 0.8943421840667725}\n",
            "2 [34.80692436640056/s]: {'loss': 0.8932392001152039}\n",
            "4 [55.54817733337748/s]: {'loss': 0.905632734298706}\n",
            "8 [80.39685643089898/s]: {'loss': 0.8966636657714844}\n",
            "16 [101.39067268584827/s]: {'loss': 0.8993918895721436}\n",
            "32 [117.50067234424026/s]: {'loss': 0.8806756138801575}\n",
            "64 [126.09115317769741/s]: {'loss': 0.8531259894371033}\n",
            "128 [132.23278171369572/s]: {'loss': 0.7658491134643555}\n",
            "256 [157.70978300143648/s]: {'loss': 0.7173314094543457}\n",
            "512 [172.29725116597388/s]: {'loss': 0.7053846120834351}\n",
            "1000 [177.70340617092123/s]: {'loss': 0.7030209302902222}\n",
            "1024 [154.22337910652647/s]: {'loss': 0.6978725790977478}\n",
            "2000 [175.1152071304068/s]: {'loss': 0.7032450437545776}\n",
            "\n",
            "================================================================================\n",
            "Gadget 2 example training: maximal_coupling_loss_matrix_fn\n",
            "================================================================================\n",
            "0 [0.06030789709325502/s]: {'loss': 0.687383770942688}\n",
            "1 [26.342655805452797/s]: {'loss': 0.6821712851524353}\n",
            "2 [26.86348905427389/s]: {'loss': 0.6825854182243347}\n",
            "4 [38.58534341594451/s]: {'loss': 0.6895071864128113}\n",
            "8 [48.491726424283414/s]: {'loss': 0.6796791553497314}\n",
            "16 [55.47700116560715/s]: {'loss': 0.6906272768974304}\n",
            "32 [58.346553617399756/s]: {'loss': 0.6834739446640015}\n",
            "64 [61.75529727889278/s]: {'loss': 0.6774040460586548}\n",
            "128 [64.12199724627143/s]: {'loss': 0.6704298853874207}\n",
            "256 [64.11760156365136/s]: {'loss': 0.5924275517463684}\n",
            "512 [64.55482885520316/s]: {'loss': 0.5199288129806519}\n",
            "1000 [64.76449748574937/s]: {'loss': 0.4943510591983795}\n",
            "1024 [61.70430714572062/s]: {'loss': 0.5035367012023926}\n",
            "2000 [64.97544841002123/s]: {'loss': 0.48919838666915894}\n",
            "\n",
            "================================================================================\n",
            "Gadget 1 example training: squared_loss_matrix_fn\n",
            "================================================================================\n",
            "0 [0.2376855749763451/s]: {'loss': 18.331106185913086}\n",
            "1 [35.70684033541906/s]: {'loss': 17.67399024963379}\n",
            "2 [36.40699269135288/s]: {'loss': 17.708940505981445}\n",
            "4 [59.55999233188728/s]: {'loss': 17.746984481811523}\n",
            "8 [87.59896200456342/s]: {'loss': 17.80404281616211}\n",
            "16 [115.37036387579468/s]: {'loss': 17.33456802368164}\n",
            "32 [135.10591428047988/s]: {'loss': 17.130428314208984}\n",
            "64 [138.614153093672/s]: {'loss': 16.159027099609375}\n",
            "128 [158.05987464045245/s]: {'loss': 15.20871353149414}\n",
            "256 [160.23610572634107/s]: {'loss': 14.156964302062988}\n",
            "512 [168.3065596247084/s]: {'loss': 14.413787841796875}\n",
            "1000 [169.20956753695305/s]: {'loss': 14.06224250793457}\n",
            "1024 [154.17637346648084/s]: {'loss': 13.972649574279785}\n",
            "2000 [170.6096715333867/s]: {'loss': 13.880552291870117}\n",
            "\n",
            "================================================================================\n",
            "Gadget 2 example training: squared_loss_matrix_fn\n",
            "================================================================================\n",
            "0 [0.0610760644918242/s]: {'loss': 14.064082145690918}\n",
            "1 [27.096037314107782/s]: {'loss': 13.813591957092285}\n",
            "2 [26.942002453767046/s]: {'loss': 13.833362579345703}\n",
            "4 [38.255060857985875/s]: {'loss': 13.818146705627441}\n",
            "8 [48.211914744431574/s]: {'loss': 13.65810775756836}\n",
            "16 [54.067031040528/s]: {'loss': 14.097822189331055}\n",
            "32 [58.902714521068326/s]: {'loss': 13.91343879699707}\n",
            "64 [59.846011848123574/s]: {'loss': 13.534945487976074}\n",
            "128 [62.714754711162406/s]: {'loss': 12.761678695678711}\n",
            "256 [64.02110986589062/s]: {'loss': 10.482722282409668}\n",
            "512 [64.72457645994288/s]: {'loss': 9.658382415771484}\n",
            "1000 [65.03163833364088/s]: {'loss': 8.652952194213867}\n",
            "1024 [61.739540777367736/s]: {'loss': 8.443416595458984}\n",
            "2000 [65.2453540083098/s]: {'loss': 7.784237861633301}\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cjUYS_-opZF3"
      },
      "source": [
        "gadget_1_maximal_theta = results[0].params\n",
        "gadget_1_maximal = experiments[0].model.bind(gadget_1_maximal_theta)\n",
        "\n",
        "gadget_2_maximal_theta = results[1].params\n",
        "gadget_2_maximal = experiments[1].model.bind(gadget_2_maximal_theta)\n",
        "\n",
        "gadget_1_variance_theta = results[2].params\n",
        "gadget_1_variance = experiments[2].model.bind(gadget_1_variance_theta)\n",
        "\n",
        "gadget_2_variance_theta = results[3].params\n",
        "gadget_2_variance = experiments[3].model.bind(gadget_2_variance_theta)"
      ],
      "execution_count": 31,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 369
        },
        "id": "ClztjoGhp85y",
        "outputId": "17264a7e-8816-4c36-9d70-ce5debeb78e8"
      },
      "source": [
        "g1_maximal_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_1_maximal.sample, second_kwargs={\"transpose\": True}),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "g2_maximal_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_2_maximal.sample),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "\n",
        "_, axs = plt.subplots(ncols=2, figsize=(12,6))\n",
        "axs[0].imshow(g1_maximal_pq, vmin=0)\n",
        "axs[1].imshow(g2_maximal_pq, vmin=0)"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b854b8450>"
            ]
          },
          "metadata": {},
          "execution_count": 32
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAAFPCAYAAABJfdYtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAARx0lEQVR4nO3d26vld3nH8c+z956DkzNtoWYSTcRUDSnVOBVNxLYqxFO1UC8UtK20BEo9VhDtjf+AeLgQy+ABrKIX0QuRYFqr1hba6JgI5mQb1OagYlqqpkmTmWS+vdh7IAadWXvP/q3f5FmvFwSyd1bW8/1lzzx5z2/WnlVjjAAAQCdrcx8AAAB2m8gFAKAdkQsAQDsiFwCAdkQuAADtbEzxpAcu2DfOu/DAFE99Ug/cVkufCfTyUB7I0fHwSi2TfefvH2c/+Zylzz12x6NLnwn0crKdPUnknnfhgbzpM38wxVOf1Defvb70mUAvN45/nPsIS3f2k8/JNZ/4o6XPve+qny59JtDLyXa2lysAANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgnYUit6peVlXfrao7q+rdUx8KgJ2zswEWiNyqWk/y4SQvT3J5ktdX1eVTHwyA7bOzATYtcif3eUnuHGN8b4xxNMlnk7xm2mMBsEN2NkAWi9yDSe5+zMf3bH0OgDOPnQ2QXfzGs6q6tqqOVNWRB//n4d16WgAm8Nid/dBPH5r7OAC7bpHIvTfJxY/5+KKtz/2CMcbhMcahMcahAxfs263zAbA9297Z+8/fv7TDASzLIpH7zSSXVdWlVbU3yeuSfGHaYwGwQ3Y2QJKNUz1gjPFIVb05yQ1J1pN8fIxx6+QnA2Db7GyATaeM3CQZY1yf5PqJzwLALrCzAbzjGQAADYlcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtLPSOZ9v1wG2Vb165d4qnPqm//I/vLn3mCR+57OmzzQY4HcfueDT3XfXTpc996S33L33mCV++4pzZZgPL4U4uAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoJ2NKZ601taytn/fFE99Un/7rGcufeYJH/zB12eZ+/ZLrpplLtDM2vrSR375t89d+swT3nHnbbPM/cDTnzXLXFhF7uQCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7p4zcqrq4qr5aVbdV1a1V9bZlHAyA7bOzATZtLPCYR5K8c4xxU1Wdk+RbVfUPY4zbJj4bANtnZwNkgTu5Y4wfjTFu2vr7+5PcnuTg1AcDYPvsbIBN23pNblVdkuQ5SW6c4jAA7B47G1hli7xcIUlSVWcn+VySt48xfv5L/vm1Sa5Nkv111q4dEIDt29bOzoElnw5gegvdya2qPdlclp8eY3z+lz1mjHF4jHFojHFob+3fzTMCsA3b3dl7sm+5BwRYgkX+dIVK8rEkt48x3j/9kQDYKTsbYNMid3KvTvLGJC+uqm9v/fWKic8FwM7Y2QBZ4DW5Y4x/SVJLOAsAp8nOBtjkHc8AAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaOeU73i2I1Wpffsmeeoz1Tue/qJZ5l5/7zdmmZskrzh45WyzgV1UlVpfn2HwHDM3feCyy2eZ+7G7/nmWuUny50954WyzYQ7u5AIA0I7IBQCgHZELAEA7Ih
cAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADa2ZjkWdcqtX/fJE99UhvTXM5C1tdnGfvKpz5vlrlJcsMPj8wy95oLnz3LXOiqaqadPaejNcvYv3ja788yN0muv/cbs8x9xcErZ5kL7uQCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7C0duVa1X1c1V9cUpDwTA6bOzgVW3nTu5b0ty+1QHAWBX2dnASlsocqvqoiSvTPLRaY8DwOmyswEWv5P7wSTvSnL8Vz2gqq6tqiNVdeTo8f/blcMBsCPb29njoeWdDGBJThm5VfWqJD8ZY3zrZI8bYxweYxwaYxzau/akXTsgAIvb0c6u/Us6HcDyLHIn9+okr66qHyT5bJIXV9WnJj0VADtlZwNkgcgdY7xnjHHRGOOSJK9L8pUxxhsmPxkA22ZnA2zy5+QCANDOxnYePMb4WpKvTXISAHaVnQ2sMndyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhnW+94trD19Yxzz57kqU+mHnxo6TNX2cue+rxZ5n7x3n+dZW6SvOrgc2ebDZNZX8/auecsfew4emzpM0+oqlnmHn/44VnmJskrZ9rZN/zwyCxzk+SaC58922zm504uAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoJ2NKZ50rFWOn7N/iqc+qbW1WvrME2rdrxeW5Q8vecFssz93z9dnmfvHFz1/lrmsiPW1jPPOXv7co8eWP3NLPbxvlrlr998/y9wkGTP9977moufOMjdJ/u7ueXb2Gy++epa5/CJlBgBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7SwUuVV1flVdV1V3VNXtVfWCqQ8GwM7Y2QDJxoKP+1CSL40xXltVe5McmPBMAJweOxtYeaeM3Ko6L8mLkvxZkowxjiY5Ou2xANgJOxtg0yIvV7g0yX1JPlFVN1fVR6vqrMc/qKquraojVXXk2CMP7vpBAVjItnf20UftbKCfRSJ3I8mVST4yxnhOkgeSvPvxDxpjHB5jHBpjHNqz4XfGAGay7Z29d93OBvpZJHLvSXLPGOPGrY+vy+YCBeDMY2cDZIHIHWP8OMndVfWMrU+9JMltk54KgB2xswE2LfqnK7wlyae3vkv3e0neNN2RADhNdjaw8haK3DHGt5McmvgsAOwCOxvAO54BANCQyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0M6ib+u7PetreeTsvZM89clsVC195glrM82e74pX02svfdEsc9//g3+aZW6S/PUlL5htNssxNtZy7NfOWvrc9YceWfrME9Z+9uAsc2uMWeYmSX5+/yxja32+r/Mbn/LCWeb+yXfvmmVuknzyGRfPNvtM404uAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+
QCANCOyAUAoJ2NKZ70+Hrl4Qv2TPHUJ1fLH3nCJP8hF7D2yKMzTU7q0ePzDD4+3zXn+DzX/M7Lfm+WuUly7b/fMcvcw7/1tFnmrqLje9by4JP3LX3u/v+e7z7LnjFmmTvrnaWHj84yth6dcWePeXb2J5/5lFnmJsk1t/xslrk3XHHuLHNPxp1cAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoZ6HIrap3VNWtVXVLVX2mqvZPfTAAdsbOBlggcqvqYJK3Jjk0xrgiyXqS1019MAC2z84G2LToyxU2kjypqjaSHEjyw+mOBMBpsrOBlXfKyB1j3JvkfUnuSvKjJD8bY/z94x9XVddW1ZGqOnLs4f/d/ZMCcEo729kPLPuYAJNb5OUKFyR5TZJLk1yY5KyqesPjHzfGODzGODTGOLRn39m7f1IATmlnO/usZR8TYHKLvFzhpUm+P8a4b4xxLMnnk1w17bEA2CE7GyCLRe5dSZ5fVQeqqpK8JMnt0x4LgB2yswGy2Gtyb0xyXZKbknxn6985PPG5ANgBOxtg08YiDxpjvDfJeyc+CwC7wM4G8I5nAAA0JHIBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALSz0DuebddYT46es/x+Xj+6vvSZc9sYT5pt9ly/QqoxZpqc5PhMs2e85sPPvGyWub/77aOzzL3l9TP++JrJ8T3JA7+5/J/RdXyS/wUtNvvReb7Oex+c58d1ktTGPP+PHGsz3k+ruWY/OtPc5IYrzp1l7oX/ds4sc/f+6a/+GruTCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGinxhi7/6RV9yX5zx3+67+e5L928ThPBK65v1W73uSJe81PHWP8xtyHWCY7e9tc82pwzU8Mv3JnTxK5p6OqjowxDs19jmVyzf2t2vUmq3nNq2gVv86ueTW45ic+L1cAAKAdkQsAQDtnYuQenvsAM3DN/a3a9Sarec2raBW/zq55NbjmJ7gz7jW5AABwus7EO7kAAHBaRC4AAO2cMZFbVS+rqu9W1Z1V9e65zzO1qrq4qr5aVbdV1a1V9ba5z7QsVbVeVTdX1RfnPssyVNX5VXVdVd1RVbdX1QvmPtPUquodWz+ub6mqz1TV/rnPxO5bpb1tZ9vZnXXd2WdE5FbVepIPJ3l5ksuTvL6qLp/3VJN7JMk7xxiXJ3l+kr9agWs+4W1Jbp/7EEv0oSRfGmM8M8nvpPm1V9XBJG9NcmiMcUWS9SSvm/dU7LYV3Nt29uqws5vs7DMicpM8L8mdY4zvjTGOJvlsktfMfKZJjTF+NMa4aevv78/mT6KD855qelV1UZJXJvno3GdZhqo6L8mLknwsScYYR8cYP533VEuxkeRJVbWR5ECSH858HnbfSu1tO9vObq7lzj5TIvdgkrsf8/E9WYHlcUJVXZLkOUlunPckS/HBJO9KcnzugyzJpUnuS/KJrd/u+2hVnTX3oaY0xrg3yfuS3JXkR0l+Nsb4+3lPxQRWdm/b2a3Z2Y129pkSuSurqs5O8rkkbx9j/Hzu80ypql6V5CdjjG/NfZYl2khyZZKPjDGek+SBJN1fu3hBNu/oXZrkwiRnVdUb5j0V7A47uz07u9HOPlMi994kFz/m44u2PtdaVe3J5rL89Bjj83OfZwmuTvLqqvpBNn9r88VV9al5jzS5e5LcM8Y4ccfnumwu0M5emuT7Y4z7xhjHknw+yVUzn4ndt3J72862s5tqu7PPlMj9ZpLLqurSqtqbzRc8f2HmM02qqiqbr/
m5fYzx/rnPswxjjPeMMS4aY1ySza/xV8YYLX61+KuMMX6c5O6qesbWp16S5LYZj7QMdyV5flUd2Ppx/pI0/8aNFbVSe9vOtrMba7uzN+Y+QJKMMR6pqjcnuSGb39X38THGrTMfa2pXJ3ljku9U1be3Pvc3Y4zrZzwT03hLkk9vhcD3krxp5vNMaoxxY1Vdl+SmbH5H+s1p9laRrOTetrNXh53dZGd7W18AANo5U16uAAAAu0bkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANr5fy42icEvDdlyAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 864x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NOMiAsG6qOLa"
      },
      "source": [
        "We see that, after optimizing them to be closer to a maximal coupling, both gadgets have adapted to put more probability mass on the diagonal than they did at initialization time."
      ]
    },
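    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A simple way to quantify this is to compare the diagonal mass $\\sum_i P_{ii}$ of an estimated joint with $\\sum_i \\min(p_i, q_i)$. No coupling of $p$ and $q$ can exceed this upper bound, and a maximal coupling attains it. The hypothetical helper below computes both numbers. It normalizes the joint first, in case it holds counts rather than probabilities. So that the cell runs on its own, it is demonstrated on an independent coupling. Calling `diagonal_mass_report(g1_maximal_pq, p_logits, q_logits)` (and likewise for `g2_maximal_pq`) should apply it to the trained gadgets, assuming those arrays convert cleanly with `np.asarray`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def diagonal_mass_report(joint, logits_1, logits_2):\n",
        "  # Normalize in case `joint` holds sample counts rather than probabilities.\n",
        "  joint = np.asarray(joint, dtype=np.float64)\n",
        "  joint = joint / joint.sum()\n",
        "  # Marginals implied by the logits (softmax).\n",
        "  logits_1 = np.asarray(logits_1, dtype=np.float64)\n",
        "  logits_2 = np.asarray(logits_2, dtype=np.float64)\n",
        "  p = np.exp(logits_1 - np.max(logits_1))\n",
        "  p = p / p.sum()\n",
        "  q = np.exp(logits_2 - np.max(logits_2))\n",
        "  q = q / q.sum()\n",
        "  # No coupling of p and q can place more than sum_i min(p_i, q_i) on the diagonal.\n",
        "  return float(np.trace(joint)), float(np.minimum(p, q).sum())\n",
        "\n",
        "demo_base = np.arange(10) - 4.5\n",
        "demo_p_logits, demo_q_logits = 0.1 * demo_base, -0.1 * demo_base\n",
        "demo_p = np.exp(demo_p_logits) / np.exp(demo_p_logits).sum()\n",
        "demo_q = np.exp(demo_q_logits) / np.exp(demo_q_logits).sum()\n",
        "diag_mass, diag_bound = diagonal_mass_report(np.outer(demo_p, demo_q), demo_p_logits, demo_q_logits)\n",
        "print(f'diagonal mass {diag_mass:.3f} vs. upper bound {diag_bound:.3f}')"
      ],
      "execution_count": null,
      "outputs": []
    },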
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 369
        },
        "id": "xW9NPHMtqFrV",
        "outputId": "32817098-af81-4586-9dfc-dfac4dc6d52d"
      },
      "source": [
        "g1_variance_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_1_variance.sample, second_kwargs={\"transpose\": True}),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "g2_variance_pq = coupling_util.joint_from_samples(\n",
        "    coupling_util.sampler_from_common_random_numbers(gadget_2_variance.sample),\n",
        "    logits_1=p_logits,\n",
        "    logits_2=q_logits,\n",
        "    rng=jax.random.PRNGKey(42),\n",
        "    num_samples=100_000)\n",
        "\n",
        "_, axs = plt.subplots(ncols=2, figsize=(12,6))\n",
        "axs[0].imshow(g1_variance_pq, vmin=0)\n",
        "axs[1].imshow(g2_variance_pq, vmin=0)"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7f5b8611ae10>"
            ]
          },
          "metadata": {},
          "execution_count": 33
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAArkAAAFPCAYAAABJfdYtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASiElEQVR4nO3dX4zld3nf8c8zOzte2xuDMVQVu9vaSSjEdUMctobYatKatOFPBGrNhQOkNDe+aRJAaRNoLlAu0igSishFFMklyUXjhEqGVAhR/qgQVe2FxWJQwd6kdW1iL5DGRQ5g43h3dr+9mNnUoWH3zOyc810/83pJljzj4/N8f97ZZ9/785k9NcYIAAB0sjb7AAAAsNdELgAA7YhcAADaEbkAALQjcgEAaGd9GU96+NqD47ojh5bx1Bf0xAMHVz4T6OUv8lROj2dq9jlW6boXrI1jx5byy8EFPfzfD698JtDLhXb2UrbadUcO5RfufcUynvqC/uCmv7nymX/p3Nl5s4E9c9/4z7OPsHLHjq3nEx994crnvuXYbSufCfRyoZ3t5QoAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoZ6HIrarXVNUfV9VDVfWuZR8KgN2zswEWiNyqOpDkN5K8NsmNSX6iqm5c9sEA2Dk7G2DLIndyb0ny0Bjj4THG6SQfSPLG5R4LgF2yswGyWOQeSfLYsz4+tf05AC4/djZA9vAbz6rqrqo6UVUnnnzizF49LQBL8Oyd/bWvnZt9HIA9t0jkfjnJsWd9fHT7c3/FGOPuMcbxMcbxw9ce3KvzAbAzO97Z113nD9oB+llks30myUuq6oaq2khyZ5IPL/dYAOySnQ2QZP1iDxhjbFbVTyf5eJIDSX57jPHA0k8GwI7Z2QBbLhq5STLG+GiSjy75LADsATsbwDueAQDQkMgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANDOQu94tlNPPLiR/3jz0WU89QX94kP3rXzmeb/8va+YMrfWasrcJBlnz04aPObMhaYeOfn8vO2WO1Y+9w9OfXjlM8/7p0dvmTYbWA13cgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtrC/jSasqtbGxjKe+oF+56daVzzzv3//JJ6fM/ec3/MiUuUmSmvR7pHF2zlxoamwczOaxF6187pte/pqVzzzvA499ZMrcO4/N+3UK9ht3cgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoJ2LRm5VHauqT1fVg1X1QFW9fRUHA2Dn7GyALesLPGYzyc+NMe6vqu9K8tmq+uQY48Elnw2AnbOzAbLAndwxxlfHGPdv//03k5xMcmTZBwNg5+xsgC07ek1uVV2f5OYk9y3jMADsHTsb2M8WeblCkqSqDif5YJJ3jDG+8df887uS3JUkh+rqPTsgADu3o5298bwVnw5g+Ra6k1tVB7O1LO8ZY3zor3vMGOPuMcbxMcbxjTq0l2cEYAd2urMPHnRjAuhnkT9doZL8VpKTY4xfW/6RANgtOxtgyyJ3cm9L8pNJbq+qz2//9bolnwuA3bGzAbLAa3LHGP81Sa3gLABcIjsbYIt3PAMAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgnYu+49murK2lDh1aylNf0Di3+pnb3nbja6bM/Q9f+sSUuUly5/
f9kylzzz355JS5SZIx5s2GJRkHKqevvWLlc698wfNXPvO8N7/sH0+Z+74vfXLK3CR5x/W3TpsNM7iTCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhnfSnPuraWOnzVUp76gjbPrn7meTXn9wtvfuUdU+YmyXu/8MEpc//Vy/7RlLlJcu70mUmDJ35t095Yq2xefWDlc889b8KvE9vWvnH1lLnvfOm8/fWLD983Ze4vf/cPTJkL7uQCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7C0duVR2oqs9V1UeWeSAALp2dDex3O7mT+/YkJ5d1EAD2lJ0N7GsLRW5VHU3y+iTvX+5xALhUdjbA4ndy35fk55Oc+04PqKq7qupEVZ04fe5be3I4AHZlRzv7zDNPru5kACty0citqh9P8mdjjM9e6HFjjLvHGMfHGMc31q7aswMCsLjd7OyDVxxe0ekAVmeRO7m3JXlDVX0pyQeS3F5Vv7vUUwGwW3Y2QBaI3DHGu8cYR8cY1ye5M8mnxhhvXfrJANgxOxtgiz8nFwCAdtZ38uAxxh8m+cOlnASAPWVnA/uZO7kAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtLOjdzxb2FplXHVoKU99IfX0Myuf+ZezNzenzB3fenrK3CT51694/ZS5v/c/PzJlbpK8+SW3T5l77plzU+YmScaYN5uVGAeSZ65Z/T2PKw5vrHzmeRvXHJ4yd23SrxVJ8m+/75VT5r77f31mytwk+ZXv+f5ps5nPnVwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAO+vLeNJxYC1nr95YxlNf0IGVT/x/atbcMSZNTsbm5pS5b/n+10+ZmyT/6eFPT5n7Y0dunjI3SVKTvronfm3vN2ubI1d+7ezK5z79ooMrn3ne+tevnDL3wJk5ezNJ1ib9XP7Vv3vLlLlJ8ksP/7cpc9/z3a+YMpe/yp1cAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoZ6HIrarnV9W9VfVHVXWyqn5o2QcDYHfsbIBkfcHH/XqSj40x3lRVG0muWuKZALg0djaw7100cqvqeUl+OMm/SJIxxukkp5d7LAB2w84G2LLIyxVuSPJ4kt+pqs9V1fur6upvf1BV3VVVJ6rqxJkzT+35QQFYyM539mk7G+hnkchdT/KDSX5zjHFzkqeSvOvbHzTGuHuMcXyMcfzgwf9vnwKwGjvf2Rt2NtDPIpF7KsmpMcZ92x/fm60FCsDlx84GyAKRO8b40ySPVdVLtz/16iQPLvVUAOyKnQ2wZdE/XeFnktyz/V26Dyf5qeUdCYBLZGcD+95CkTvG+HyS40s+CwB7wM4G8I5nAAA0JHIBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALSz6Nv67shYq2we3ljGU1947oF5zX6gasrcmb9LqWdOT5k7vvX0lLlJ8trvvXXK3A8+9qkpc5PkjqOvmjN40s+pjDljZ1p7ejOHv/i/Vz739N96wcpnnvfETddMmXvtyUlf10nWNg7Ombu+lN
RYyC+9/B9OmXvPYx+bMjdJ3nLstmmzLzfu5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaWV/Gk471yjPXLuWpL+iKWvnI6ers2XmzZ82deM0Z56aMfdP3/MiUuUnyq4/8lylzf+GGV06Zuy+d3cz42hMrH7txbs7PpyS55ux1U+b++csOT5mbJIdPbUyZu37NlVPmJsna1781Ze7b/v4/mzI3ST7+lY9PmftjL/6BKXMvxJ1cAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoZ6HIrap3VtUDVfXFqvr9qjq07IMBsDt2NsACkVtVR5L8bJLjY4ybkhxIcueyDwbAztnZAFsWfbnCepIrq2o9yVVJvrK8IwFwiexsYN+7aOSOMb6c5L1JHk3y1SRfH2N84tsfV1V3VdWJqjpx5pkn9/6kAFzUbnb26XN/sepjAizdIi9XuDbJG5PckOTFSa6uqrd+++PGGHePMY6PMY4fvOLw3p8UgIvazc7eWPOSXaCfRV6u8KNJHhljPD7GOJPkQ0luXe6xANglOxsgi0Xuo0leVVVXVVUleXWSk8s9FgC7ZGcDZLHX5N6X5N4k9yf5wva/c/eSzwXALtjZAFvWF3nQGOM9Sd6z5LMAsAfsbADveAYAQEMiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOwu949lOjbXk9Hetvp/XNg+sfOZsdfbQtNlrB+f8914bY8rcJMnm5pSxtbExZW6SvOvv/IMpc+84eWrK3P9xx5kpc2caZ8/l7JNPrXzu+gtfsPKZ5x14as6P8zWPTBmbJHnqxVfMGfw35u2vg9+8csrcKx+bd82v+3u3T5n77x798JS5b3jdN7/jP3MnFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANCOyAUAoB2RCwBAOyIXAIB2RC4AAO2IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7IhcAgHZELgAA7YhcAADaEbkAALQjcgEAaEfkAgDQjsgFAKAdkQsAQDsiFwCAdkQuAADtiFwAANoRuQAAtCNyAQBoR+QCANBOjTH2/kmrHk/yJ7v811+Y5P/s4XGeC1xzf/vtepPn7jX/7THGi2YfYpXs7B1zzfuDa35u+I47eymReymq6sQY4/jsc6ySa+5vv11vsj+veT/ajz/Ornl/cM3PfV6uAABAOyIXAIB2LsfIvXv2ASZwzf3tt+tN9uc170f78cfZNe8Prvk57rJ7TS4AAFyqy/FOLgAAXBKRCwBAO5dN5FbVa6rqj6vqoap61+zzLFtVHauqT1fVg1X1QFW9ffaZVqWqDlTV56rqI7PPsgpV9fyqureq/qiqTlbVD80+07JV1Tu3v66/WFW/X1WHZp+Jvbef9radbWd31nVnXxaRW1UHkvxGktcmuTHJT1TVjXNPtXSbSX5ujHFjklcl+Zf74JrPe3uSk7MPsUK/nuRjY4yXJXl5ml97VR1J8rNJjo8xbkpyIMmdc0/FXtuHe9vO3j/s7CY7+7KI3CS3JHlojPHwGON0kg8keePkMy3VGOOrY4z7t//+m9n6SXRk7qmWr6qOJnl9kvfPPssqVNXzkvxwkt9KkjHG6THGn8891UqsJ7myqtaTXJXkK5PPw97bV3vbzrazm2u5sy+XyD2S5LFnfXwq+2B5nFdV1ye5Ocl9c0+yEu9L8vNJzs0+yIrckOTxJL+z/b/73l9VV88+1DKNMb6c5L1JHk3y1SRfH2N8Yu6pWIJ9u7ft7Nbs7EY7+3
KJ3H2rqg4n+WCSd4wxvjH7PMtUVT+e5M/GGJ+dfZYVWk/yg0l+c4xxc5KnknR/7eK12bqjd0OSFye5uqreOvdUsDfs7Pbs7EY7+3KJ3C8nOfasj49uf661qjqYrWV5zxjjQ7PPswK3JXlDVX0pW/9r8/aq+t25R1q6U0lOjTHO3/G5N1sLtLMfTfLIGOPxMcaZJB9KcuvkM7H39t3etrPt7Kba7uzLJXI/k+QlVXVDVW1k6wXPH558pqWqqsrWa35OjjF+bfZ5VmGM8e4xxtExxvXZ+jH+1Bijxe8Wv5Mxxp8meayqXrr9qVcneXDikVbh0SSvqqqrtr/OX53m37ixT+2rvW1n29mNtd3Z67MPkCRjjM2q+ukkH8/Wd/X99hjjgcnHWrbbkvxkki9U1ee3P/dvxhgfnXgmluNnktyzHQIPJ/mpyedZqjHGfVV1b5L7s/Ud6Z9Ls7eKZF/ubTt7/7Czm+xsb+sLAEA7l8vLFQAAYM+IXAAA2hG5AAC0I3IBAGhH5AIA0I7IBQCgHZELAEA7/xdMIbjXSSMWrwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 864x432 with 2 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RALZewqusJmu"
      },
      "source": [
        "When trained to reduce variance, Gadget 2 learns a coupling that bears some resemblance to the inverse CDF coupling, whereas Gadget 1 again pulls a large amount of mass onto the diagonal.\n",
        "\n"
      ]
    },
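    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, the inverse CDF coupling draws one $u \\sim \\mathrm{Uniform}(0, 1)$ and returns $x = F_p^{-1}(u)$, $y = F_q^{-1}(u)$. The squared score is a convex function of $x - y$, so this coupling minimizes $\\mathbb{E}[(x - y)^2]$ over all couplings of $p$ and $q$. With the marginals fixed, that is the same as minimizing the variance of $x - y$. The sketch below computes the coupling's joint exactly from the overlaps of the two CDF intervals. It uses NumPy, and `inverse_cdf_coupling` is an illustrative name, not a library function. Passing the softmax of `p_logits` and `q_logits` should give a matrix that can be shown with `imshow` next to `g2_variance_pq`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def inverse_cdf_coupling(p, q):\n",
        "  # Outcome i of p owns the interval [cp[i], cp[i + 1]) of [0, 1]; likewise for q.\n",
        "  cp = np.concatenate([[0.0], np.cumsum(p)])\n",
        "  cq = np.concatenate([[0.0], np.cumsum(q)])\n",
        "  # joint[i, j] = length of the overlap of interval i (for p) and interval j (for q),\n",
        "  # i.e. the probability that one shared uniform u selects both i and j.\n",
        "  lo = np.maximum(cp[:-1, None], cq[None, :-1])\n",
        "  hi = np.minimum(cp[1:, None], cq[None, 1:])\n",
        "  return np.clip(hi - lo, 0.0, None)\n",
        "\n",
        "icdf_base = np.arange(10) - 4.5\n",
        "icdf_p = np.exp(0.1 * icdf_base) / np.exp(0.1 * icdf_base).sum()\n",
        "icdf_q = np.exp(-0.1 * icdf_base) / np.exp(-0.1 * icdf_base).sum()\n",
        "icdf_joint = inverse_cdf_coupling(icdf_p, icdf_q)\n",
        "\n",
        "# Compare E[(x - y)^2] with the independent coupling of the same marginals.\n",
        "icdf_idx = np.arange(10, dtype=np.float64)\n",
        "icdf_sq = np.square(icdf_idx[:, None] - icdf_idx[None, :])\n",
        "print('inverse CDF:', float(np.sum(icdf_joint * icdf_sq)))\n",
        "print('independent:', float(np.sum(np.outer(icdf_p, icdf_q) * icdf_sq)))"
      ],
      "execution_count": null,
      "outputs": []
    },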
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "EUdsSGtmcM9Y"
      },
      "source": [
        "## MDP counterfactual treatment effect\n",
        "In this part, we show how to use an MDP to define the interventional distributions, so that we can couple transitions under a behavior policy (e.g. a physician) with transitions under a target policy (e.g. an RL policy).\n",
        "\n",
        "Following Oberst and Sontag (2019), we consider a sepsis management simulator and take the following steps:\n",
        "1. Learn an MDP by interacting with the simulator. This MDP represents the \"true\" behavior of sepsis management.\n",
        "2. Train a behavior policy (physician) over the MDP.\n",
        "3. Generate patient trajectories (data) using the behavior policy, and construct an estimated MDP from the sampled data.\n",
        "4. Learn an RL policy over the estimated MDP.\n",
        "\n",
        "The MDP provides the transition probabilities $pr(s' \\mid s, a)$, which give us two interventional distributions depending on whether $a$ is chosen by the physician policy or by the RL policy. In the counterfactual setting, we sample the counterfactual next state $s'_{cf}$ conditioned on the observed next state $s'_{obs}$ (from step 3).\n",
        "\n"
      ]
    },
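    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Under a Gumbel-max SCM, as proposed by Oberst and Sontag (2019), the counterfactual next state is obtained in two steps. First, infer Gumbel noise that is consistent with the observed transition. Then reuse that noise with the counterfactual action's transition logits. The cell below is a minimal NumPy sketch of this posterior (top-down) Gumbel sampling for a single transition, built from truncated Gumbels. It assumes that finite logits $\\log pr(\\cdot \\mid s, a)$ are available for both actions. The function names are hypothetical and are not taken from the `gumbel_max_causal_gadgets_part2` repository."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def posterior_gumbel_noise(logits, observed, rng):\n",
        "  # Sample Gumbel noise g conditioned on argmax(logits + g) == observed.\n",
        "  logits = np.asarray(logits, dtype=np.float64)\n",
        "  # The maximum of the perturbed logits is Gumbel(logsumexp(logits)),\n",
        "  # independent of which index attains it.\n",
        "  top = np.logaddexp.reduce(logits) + rng.gumbel()\n",
        "  # Every other perturbed logit is Gumbel(logits_i) truncated to lie below `top`.\n",
        "  perturbed = logits + rng.gumbel(size=logits.shape)\n",
        "  truncated = -np.logaddexp(-top, -perturbed)\n",
        "  truncated[observed] = top\n",
        "  return truncated - logits\n",
        "\n",
        "def counterfactual_next_state(obs_logits, cf_logits, observed, rng):\n",
        "  # Reuse the inferred noise with the counterfactual transition logits.\n",
        "  noise = posterior_gumbel_noise(obs_logits, observed, rng)\n",
        "  return int(np.argmax(np.asarray(cf_logits, dtype=np.float64) + noise))\n",
        "\n",
        "cf_rng = np.random.default_rng(0)\n",
        "obs_logits = np.log(np.array([0.5, 0.3, 0.2]))  # e.g. pr(s' | s, a_physician)\n",
        "cf_logits = np.log(np.array([0.2, 0.3, 0.5]))   # e.g. pr(s' | s, a_RL)\n",
        "cf_samples = [counterfactual_next_state(obs_logits, cf_logits, 0, cf_rng) for _ in range(10_000)]\n",
        "print(np.bincount(cf_samples, minlength=3) / len(cf_samples))"
      ],
      "execution_count": null,
      "outputs": []
    },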
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_WhUgr7AcM9e",
        "outputId": "069a47bb-097a-4f65-ada9-b500a0a06b7a"
      },
      "source": [
        "%cd ..\n",
        "!git clone https://github.com/GuyLor/gumbel_max_causal_gadgets_part2.git"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content\n",
            "Cloning into 'gumbel_max_causal_gadgets_part2'...\n",
            "remote: Enumerating objects: 90, done.\u001b[K\n",
            "remote: Counting objects: 100% (90/90), done.\u001b[K\n",
            "remote: Compressing objects: 100% (73/73), done.\u001b[K\n",
            "remote: Total 90 (delta 8), reused 90 (delta 8), pack-reused 0\u001b[K\n",
            "Unpacking objects: 100% (90/90), done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gytXIQGh2zvv"
      },
      "source": [
        "\n",
        "import os\n",
        "os.chdir(\"gumbel_max_causal_gadgets_part2\")"
      ],
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EFkBi1ILcln1",
        "outputId": "e8d6d32c-a5a9-434d-a75f-365d13c1bef9"
      },
      "source": [
        "\n",
        "!pip install -r requirements.txt\n"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (1.10.0+cu111)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 2)) (1.1.5)\n",
            "Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (0.11.2)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 4)) (4.62.3)\n",
            "Collecting pymdptoolbox\n",
            "  Downloading pymdptoolbox-4.0-b3.zip (29 kB)\n",
            "Requirement already satisfied: jax[cuda11_cudnn82] in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 6)) (0.2.21)\n",
            "Requirement already satisfied: flax in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 7)) (0.3.6)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 8)) (1.19.5)\n",
            "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 9)) (0.0.9)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->-r requirements.txt (line 1)) (3.10.0.2)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->-r requirements.txt (line 2)) (2018.9)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->-r requirements.txt (line 2)) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->-r requirements.txt (line 2)) (1.15.0)\n",
            "Requirement already satisfied: matplotlib>=2.2 in /usr/local/lib/python3.7/dist-packages (from seaborn->-r requirements.txt (line 3)) (3.2.2)\n",
            "Requirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.7/dist-packages (from seaborn->-r requirements.txt (line 3)) (1.4.1)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->-r requirements.txt (line 3)) (1.3.2)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->-r requirements.txt (line 3)) (2.4.7)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=2.2->seaborn->-r requirements.txt (line 3)) (0.11.0)\n",
            "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax->-r requirements.txt (line 7)) (1.0.2)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax[cuda11_cudnn82]->-r requirements.txt (line 6)) (3.3.0)\n",
            "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax[cuda11_cudnn82]->-r requirements.txt (line 6)) (0.12.0)\n",
            "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax->-r requirements.txt (line 9)) (0.0.8)\n",
            "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax->-r requirements.txt (line 9)) (0.1.71+cuda111)\n",
            "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->-r requirements.txt (line 9)) (0.11.2)\n",
            "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax->-r requirements.txt (line 9)) (0.1.6)\n",
            "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax->-r requirements.txt (line 9)) (2.0)\n",
            "\u001b[33mWARNING: jax 0.2.21 does not provide the extra 'cuda11_cudnn82'\u001b[0m\n",
            "Building wheels for collected packages: pymdptoolbox\n",
            "  Building wheel for pymdptoolbox (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pymdptoolbox: filename=pymdptoolbox-4.0b3-py3-none-any.whl size=25656 sha256=41f498d80804a1ecde2421004d5f62369a4ea2f2f5497844a77fcaa0e00d3b54\n",
            "  Stored in directory: /root/.cache/pip/wheels/09/a8/27/a76d688633fa5d71984c288499c2170a8d06726135b8e216fd\n",
            "Successfully built pymdptoolbox\n",
            "Installing collected packages: pymdptoolbox\n",
            "Successfully installed pymdptoolbox-4.0b3\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aEAl8dnVcM9f"
      },
      "source": [
        "from sepsis_mdp import SepsisMDP\n",
        "import numpy as np\n",
        "import cf.utils as utils\n",
        "from joint_predictor import Coupler\n",
        "from cf import fixed_mechanisms as fm"
      ],
      "execution_count": 37,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2sW1IKEnu6Yy"
      },
      "source": [
        "### Observations and interventional distributions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "TcxS1rV_cM9g"
      },
      "source": [
        "# Setup of the sepsis simulator:\n",
        "sep = SepsisMDP()\n",
        "\n",
        "# Load an MDP built from the simulator's states and actions: the 'true' transition distributions of sepsis management\n",
        "true_mdp = sep.load_mdp_from_simulator()\n",
        "\n",
        "# Train a behavior (physician) policy over the true MDP using the policy iteration algorithm\n",
        "physician_policy = sep.get_physician_policy(true_mdp)\n",
        "\n",
        "# Sample patient trajectories by interacting with the MDP using the physician policy,\n",
        "# then construct an estimated MDP from these trajectories\n",
        "obs_samples, est_mdp = sep.simulate_patient_trajectories_and_construct_mdp(physician_policy,\n",
        "                                                                           num_steps=20,\n",
        "                                                                           num_samples=20000)\n",
        "\n",
        "# Train the target (counterfactual) policy over the estimated MDP\n",
        "cf_policy = sep.train_rl_policy(est_mdp)"
      ],
      "execution_count": 38,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-GIDRz42cM9h"
      },
      "source": [
        "Unlike Oberst and Sontag, we couple the next-state distributions of single time steps, and therefore we assign a reward to every state (instead of a [0,1] reward at trajectory completion).\n",
        "The state is composed of 6 categorical variables, each with a different number of categories.\n",
        "For each category we sample Gaussian noise that represents its energy.\n",
        "The reward of a given state is the sum of the energies of the categories taken by its variables."
      ]
    },
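    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchRewardsMd"
      },
      "source": [
        "The reward construction described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the code behind `sep.randomize_states_rewards()`: the per-variable category counts in `CATEGORY_COUNTS` are hypothetical, energies are drawn from a standard normal, and states are enumerated in lexicographic order, which need not match the simulator's state indexing. All function and variable names in the cell are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sketchRewardsCode"
      },
      "source": [
        "import itertools\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical number of categories for each of the 6 state variables\n",
        "# (illustrative only; the simulator defines the real values)\n",
        "CATEGORY_COUNTS = (3, 3, 2, 5, 2, 2)\n",
        "\n",
        "def sample_energies(category_counts, rng):\n",
        "    # One standard-normal energy per category of every variable\n",
        "    return [rng.standard_normal(k) for k in category_counts]\n",
        "\n",
        "def state_reward(state, energies):\n",
        "    # A state holds one category index per variable; its reward sums the selected energies\n",
        "    return sum(e[c] for e, c in zip(energies, state))\n",
        "\n",
        "def reward_vector_over_states(category_counts, energies):\n",
        "    # Enumerate all states in lexicographic order (an assumption) and score each one\n",
        "    states = itertools.product(*(range(k) for k in category_counts))\n",
        "    return np.array([state_reward(s, energies) for s in states])\n",
        "\n",
        "rng_rewards = np.random.default_rng(0)\n",
        "energies_sketch = sample_energies(CATEGORY_COUNTS, rng_rewards)\n",
        "rewards_sketch = reward_vector_over_states(CATEGORY_COUNTS, energies_sketch)\n",
        "print(rewards_sketch.shape)  # (360,) = 3*3*2*5*2*2 states"
      ],
      "execution_count": null,
      "outputs": []
    },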
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "x0UcIjb7cM9h"
      },
      "source": [
        "relevant_trajs_and_t = sep.search_for_relevant_tr_t(obs_samples, cf_policy, est_mdp,\n",
        "                                                    num_of_diff_p_q=sep.n_proj_states, num_gt_zero_probs=4)\n",
        "\n",
        "trajectory_idx, time_idx = relevant_trajs_and_t[0]\n",
        "\n",
        "current_state, obs_action, obs_next_state = sep.parse_samples(obs_samples, trajectory_idx, time_idx)\n",
        "cf_action = cf_policy[current_state, :].squeeze().argmax()\n",
        "\n",
        "# p (behavior) and q (target): next-state distributions under the observed and the counterfactual action\n",
        "behavior_interv_probs = est_mdp.tx_mat[0, obs_action, current_state, :].squeeze().tolist()\n",
        "target_interv_probs = est_mdp.tx_mat[0, cf_action, current_state, :].squeeze().tolist()\n"
      ],
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N12Rzsi5cM9i"
      },
      "source": [
        "The following function computes the variance of the treatment effect under each coupling, comparing the fixed mechanisms (inverse-CDF and Gumbel-max) with our two learnable mechanisms (gadget 1 and gadget 2)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "hPWpjl-_cM9j"
      },
      "source": [
        "def run_comparison(behavior_interv_probs, target_interv_probs, s_prime_obs=None, n=10, seed=0):\n",
        "    # Compare the two trained gadgets with the fixed Gumbel-max and inverse-CDF couplings.\n",
        "    # Relies on the notebook globals `c` (trained Coupler), `reward_vector`, `noise_scale` and `counterfactual`.\n",
        "    rng = np.random.default_rng(seed)  # makes the test batch reproducible\n",
        "    logits_p = np.log(np.array(behavior_interv_probs) + 1e-10).clip(min=-80.0)\n",
        "    logits_q = np.log(np.array(target_interv_probs) + 1e-10).clip(min=-80.0)\n",
        "    batch_size_test = 2000\n",
        "    # Perturb the logits with Gaussian noise (generalized setting); noise_scale = 0 keeps p and q fixed\n",
        "    batch_logits_p = np.tile(logits_p, (batch_size_test, 1)) + noise_scale * rng.standard_normal(\n",
        "        (batch_size_test, logits_p.shape[-1]))\n",
        "    batch_logits_q = np.tile(logits_q, (batch_size_test, 1)) + noise_scale * rng.standard_normal(\n",
        "        (batch_size_test, logits_q.shape[-1]))\n",
        "\n",
        "    # In the counterfactual setting, every row is conditioned on the same observed next state\n",
        "    batch_s_prime_obs = np.tile(s_prime_obs, (batch_size_test, 1)) if counterfactual else None\n",
        "    vars_gm, vars_icdf, vars_gd1, vars_gd2 = [], [], [], []\n",
        "    for i in range(n):\n",
        "        # Learnable mechanism: gadget 1\n",
        "        (s_prime_p, s_prime_q), _ = c.gadget_1.sample_from_joint(batch_logits_p, batch_logits_q, batch_s_prime_obs,\n",
        "                                                                 counterfactual=counterfactual, train=False)\n",
        "        vars_gd1.append(utils.compute_variance_treatment_effect(reward_vector, s_prime_p, s_prime_q, batch_s_prime_obs))\n",
        "\n",
        "        # Learnable mechanism: gadget 2\n",
        "        (s_prime_p, s_prime_q), _ = c.gadget_2.sample_from_joint(batch_logits_p, batch_logits_q, batch_s_prime_obs,\n",
        "                                                                 counterfactual=counterfactual, train=False)\n",
        "        vars_gd2.append(utils.compute_variance_treatment_effect(reward_vector, s_prime_p, s_prime_q, batch_s_prime_obs))\n",
        "\n",
        "        # Fixed mechanism: Gumbel-max\n",
        "        s_prime_p, s_prime_q = fm.gumbel_max_coupling(batch_logits_p, batch_logits_q, batch_s_prime_obs, counterfactual=counterfactual)\n",
        "        vars_gm.append(utils.compute_variance_treatment_effect(reward_vector, s_prime_p, s_prime_q, batch_s_prime_obs))\n",
        "\n",
        "        # Fixed mechanism: inverse-CDF\n",
        "        s_prime_p, s_prime_q = fm.inverse_cdf_coupling(batch_logits_p, batch_logits_q, batch_s_prime_obs, counterfactual=counterfactual)\n",
        "        vars_icdf.append(utils.compute_variance_treatment_effect(reward_vector, s_prime_p, s_prime_q, batch_s_prime_obs))\n",
        "\n",
        "    # Average each variance over the n repetitions\n",
        "    return np.mean(vars_gm), np.mean(vars_icdf), np.mean(vars_gd1), np.mean(vars_gd2)\n"
      ],
      "execution_count": 40,
      "outputs": []
    },
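    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCouplingsMd"
      },
      "source": [
        "For intuition, here is a self-contained NumPy sketch of the two fixed mechanisms in the interventional (non-counterfactual) setting, together with the quantity `run_comparison` reports. It is not the code in `cf.fixed_mechanisms` or `cf.utils`: we assume here that the variance of the treatment effect is the empirical variance of $r(s'_q) - r(s'_p)$ over coupled samples, and the toy p, q and rewards, as well as all function and variable names in the cell, are hypothetical. Gumbel-max shares one Gumbel vector per sample between both argmaxes, while inverse-CDF shares one uniform variable per sample; the latter therefore depends on the (arbitrary) order of the states, so which coupling yields the lower variance depends on how the rewards line up with that order."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sketchCouplingsCode"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax_sketch(logits):\n",
        "    # Numerically stable softmax along the last axis\n",
        "    z = np.exp(logits - logits.max(axis=-1, keepdims=True))\n",
        "    return z / z.sum(axis=-1, keepdims=True)\n",
        "\n",
        "def gumbel_max_coupling_sketch(logits_p, logits_q, rng):\n",
        "    # One shared Gumbel vector per sample drives both argmaxes\n",
        "    g = rng.gumbel(size=logits_p.shape)\n",
        "    return np.argmax(logits_p + g, axis=-1), np.argmax(logits_q + g, axis=-1)\n",
        "\n",
        "def inverse_cdf_coupling_sketch(logits_p, logits_q, rng):\n",
        "    # One shared uniform per sample is pushed through both inverse CDFs\n",
        "    u = rng.uniform(size=(logits_p.shape[0], 1))\n",
        "    k = logits_p.shape[-1]\n",
        "    def inverse_cdf(logits):\n",
        "        cdf = np.cumsum(softmax_sketch(logits), axis=-1)\n",
        "        # Index of the first CDF entry above u; the clip guards against round-off near 1\n",
        "        return np.minimum((cdf <= u).sum(axis=-1), k - 1)\n",
        "    return inverse_cdf(logits_p), inverse_cdf(logits_q)\n",
        "\n",
        "def variance_treatment_effect_sketch(reward, s_p, s_q):\n",
        "    # Empirical variance of the per-sample effect r(s_q) - r(s_p)\n",
        "    return np.var(reward[s_q] - reward[s_p])\n",
        "\n",
        "# Toy example with hypothetical p, q and rewards over three next states\n",
        "rng_toy = np.random.default_rng(0)\n",
        "p_toy = np.array([0.5, 0.3, 0.2])\n",
        "q_toy = np.array([0.2, 0.3, 0.5])\n",
        "reward_toy = np.array([0.0, 1.0, 2.0])\n",
        "n_toy = 100_000\n",
        "logits_p_toy = np.tile(np.log(p_toy), (n_toy, 1))\n",
        "logits_q_toy = np.tile(np.log(q_toy), (n_toy, 1))\n",
        "results_toy = {}\n",
        "for mechanism, coupling in [('Gumbel-max', gumbel_max_coupling_sketch),\n",
        "                            ('inverse-CDF', inverse_cdf_coupling_sketch)]:\n",
        "    s_p_toy, s_q_toy = coupling(logits_p_toy, logits_q_toy, rng_toy)\n",
        "    results_toy[mechanism] = variance_treatment_effect_sketch(reward_toy, s_p_toy, s_q_toy)\n",
        "    print(mechanism, results_toy[mechanism])"
      ],
      "execution_count": null,
      "outputs": []
    },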
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JD-byR92cM9k"
      },
      "source": [
        "### Training the gadgets\n",
        "In each of five trials, we sample a new realization of the state rewards and train both gadgets on it (p and q stay fixed):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "scrolled": false,
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DuqAJz4fcM9k",
        "outputId": "e4b53e91-99be-4441-a556-6ec2a1b932d1"
      },
      "source": [
        "noise_scale = 1.0      # scale of the Gaussian noise added to the logits of p and q\n",
        "counterfactual = True  # condition the couplings on the observed next state\n",
        "num_trials = 5\n",
        "vars_gm, vars_icdf, vars_gd1, vars_gd2 = [], [], [], []\n",
        "for t in range(num_trials):\n",
        "    print('='*80)\n",
        "    print(f'Trial {t}: sample new rewards')\n",
        "\n",
        "    # Fresh gadgets and a fresh reward realization for every trial\n",
        "    c = Coupler(s_dim=sep.n_proj_states, z_dim=20, hidden_features=[1024, 1024], tmp=1.0, seed=t)\n",
        "    reward_vector = sep.randomize_states_rewards()\n",
        "\n",
        "    print('---- Train gadget-1 -----')\n",
        "    c.train_gadget_1(p=behavior_interv_probs, q=target_interv_probs, s_prime_obs=obs_next_state, reward_vector=reward_vector,\n",
        "                     batch_size=64, counterfactual=counterfactual, num_iter=200, noise_scale=noise_scale)\n",
        "    print('---- Train gadget-2 -----')\n",
        "    c.train_gadget_2(p=behavior_interv_probs, q=target_interv_probs, s_prime_obs=obs_next_state, reward_vector=reward_vector,\n",
        "                     batch_size=64, counterfactual=counterfactual, num_iter=200, noise_scale=noise_scale)\n",
        "\n",
        "    # Evaluate all four mechanisms on a test batch and record their variances\n",
        "    gm, icdf, gd1, gd2 = run_comparison(behavior_interv_probs, target_interv_probs, obs_next_state, n=10)\n",
        "    print(f'Gumbel-max: {gm}, inverse-CDF: {icdf}, gadget-1: {gd1}, gadget-2: {gd2}')\n",
        "    vars_gm.append(gm); vars_icdf.append(icdf); vars_gd1.append(gd1); vars_gd2.append(gd2)\n",
        "\n",
        "print(f'Average over {num_trials} rewards realizations (same p, q)')\n",
        "print(f'Gumbel-max: {np.mean(vars_gm)}, inverse-CDF: {np.mean(vars_icdf)}, gadget-1: {np.mean(vars_gd1)}, gadget-2: {np.mean(vars_gd2)}')\n"
      ],
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "================================================================================\n",
            "Trial 0: sample new rewards\n",
            "---- Train gadget-1 -----\n",
            "0: {'loss': 6.815032958984375}\n",
            "1: {'loss': 6.83551025390625}\n",
            "2: {'loss': 6.994156837463379}\n",
            "4: {'loss': 6.778452396392822}\n",
            "8: {'loss': 6.953495502471924}\n",
            "16: {'loss': 6.9138689041137695}\n",
            "32: {'loss': 6.835074424743652}\n",
            "64: {'loss': 6.287680149078369}\n",
            "128: {'loss': 5.896256446838379}\n",
            "200: {'loss': 5.839803218841553}\n",
            "---- Train gadget-2 -----\n",
            "0: {'loss': 2.373525381088257}\n",
            "1: {'loss': 2.2369132041931152}\n",
            "2: {'loss': 2.1144003868103027}\n",
            "4: {'loss': 1.7802098989486694}\n",
            "8: {'loss': 1.1901711225509644}\n",
            "16: {'loss': 0.7079426050186157}\n",
            "32: {'loss': 0.5204874277114868}\n",
            "64: {'loss': 0.46064454317092896}\n",
            "128: {'loss': 0.43597525358200073}\n",
            "200: {'loss': 0.4254891574382782}\n",
            "Gumbel-max: 1.6229719088618957, inverse-CDF: 7.687002460750731, gadget-1: 0.9543093113932342, gadget-2: 0.021393688133929636\n",
            "================================================================================\n",
            "Trial 1: sample new rewards\n",
            "---- Train gadget-1 -----\n",
            "0: {'loss': 0.6267673969268799}\n",
            "1: {'loss': 0.6025773286819458}\n",
            "2: {'loss': 0.6116480827331543}\n",
            "4: {'loss': 0.6427431106567383}\n",
            "8: {'loss': 0.6093651056289673}\n",
            "16: {'loss': 0.6066969633102417}\n",
            "32: {'loss': 0.6027437448501587}\n",
            "64: {'loss': 0.5549300909042358}\n",
            "128: {'loss': 0.5022635459899902}\n",
            "200: {'loss': 0.4973168969154358}\n",
            "---- Train gadget-2 -----\n",
            "0: {'loss': 0.7107163071632385}\n",
            "1: {'loss': 0.7120217084884644}\n",
            "2: {'loss': 0.6897112131118774}\n",
            "4: {'loss': 0.6421249508857727}\n",
            "8: {'loss': 0.5438899993896484}\n",
            "16: {'loss': 0.37780094146728516}\n",
            "32: {'loss': 0.28407013416290283}\n",
            "64: {'loss': 0.2638854384422302}\n",
            "128: {'loss': 0.2557735741138458}\n",
            "200: {'loss': 0.2534683644771576}\n",
            "Gumbel-max: 0.06929503907013328, inverse-CDF: 0.5930441721526314, gadget-1: 0.03338696036648546, gadget-2: 0.04212520560329554\n",
            "================================================================================\n",
            "Trial 2: sample new rewards\n",
            "---- Train gadget-1 -----\n",
            "0: {'loss': 2.5762240886688232}\n",
            "1: {'loss': 2.5481934547424316}\n",
            "2: {'loss': 2.5147852897644043}\n",
            "4: {'loss': 2.5088233947753906}\n",
            "8: {'loss': 2.5528452396392822}\n",
            "16: {'loss': 2.555717945098877}\n",
            "32: {'loss': 2.5118608474731445}\n",
            "64: {'loss': 2.3755581378936768}\n",
            "128: {'loss': 2.1253859996795654}\n",
            "200: {'loss': 2.1082639694213867}\n",
            "---- Train gadget-2 -----\n",
            "0: {'loss': 2.395547389984131}\n",
            "1: {'loss': 2.339913845062256}\n",
            "2: {'loss': 2.297952651977539}\n",
            "4: {'loss': 2.0331649780273438}\n",
            "8: {'loss': 1.4600975513458252}\n",
            "16: {'loss': 0.9771777391433716}\n",
            "32: {'loss': 0.7859787940979004}\n",
            "64: {'loss': 0.7007331848144531}\n",
            "128: {'loss': 0.6741129159927368}\n",
            "200: {'loss': 0.6681551337242126}\n",
            "Gumbel-max: 0.773329142398105, inverse-CDF: 2.8497924077361865, gadget-1: 0.6092196966477113, gadget-2: 0.4091588129126439\n",
            "================================================================================\n",
            "Trial 3: sample new rewards\n",
            "---- Train gadget-1 -----\n",
            "0: {'loss': 5.906597137451172}\n",
            "1: {'loss': 5.993771553039551}\n",
            "2: {'loss': 5.758336067199707}\n",
            "4: {'loss': 6.171683311462402}\n",
            "8: {'loss': 6.061111927032471}\n",
            "16: {'loss': 5.954693794250488}\n",
            "32: {'loss': 5.937180519104004}\n",
            "64: {'loss': 5.258135795593262}\n",
            "128: {'loss': 4.856616497039795}\n",
            "200: {'loss': 4.751044273376465}\n",
            "---- Train gadget-2 -----\n",
            "0: {'loss': 3.8443615436553955}\n",
            "1: {'loss': 3.7762885093688965}\n",
            "2: {'loss': 3.3654990196228027}\n",
            "4: {'loss': 2.8285388946533203}\n",
            "8: {'loss': 1.7072570323944092}\n",
            "16: {'loss': 0.8873077034950256}\n",
            "32: {'loss': 0.5000956058502197}\n",
            "64: {'loss': 0.3169873356819153}\n",
            "128: {'loss': 0.2776269316673279}\n",
            "200: {'loss': 0.266557514667511}\n",
            "Gumbel-max: 2.7203931347806107, inverse-CDF: 7.884715517252104, gadget-1: 1.7166531677920365, gadget-2: 0.22186028940070318\n",
            "================================================================================\n",
            "Trial 4: sample new rewards\n",
            "---- Train gadget-1 -----\n",
            "0: {'loss': 2.807694911956787}\n",
            "1: {'loss': 2.7649433612823486}\n",
            "2: {'loss': 2.8492047786712646}\n",
            "4: {'loss': 2.847357749938965}\n",
            "8: {'loss': 2.843879222869873}\n",
            "16: {'loss': 2.8169002532958984}\n",
            "32: {'loss': 2.796600580215454}\n",
            "64: {'loss': 2.5771541595458984}\n",
            "128: {'loss': 2.3061718940734863}\n",
            "200: {'loss': 2.266900062561035}\n",
            "---- Train gadget-2 -----\n",
            "0: {'loss': 2.410926342010498}\n",
            "1: {'loss': 2.2446250915527344}\n",
            "2: {'loss': 2.092498540878296}\n",
            "4: {'loss': 1.7587807178497314}\n",
            "8: {'loss': 1.1251004934310913}\n",
            "16: {'loss': 0.63548743724823}\n",
            "32: {'loss': 0.30186885595321655}\n",
            "64: {'loss': 0.20435461401939392}\n",
            "128: {'loss': 0.1735151708126068}\n",
            "200: {'loss': 0.16532368957996368}\n",
            "Gumbel-max: 0.8051984927699867, inverse-CDF: 3.6378718192340016, gadget-1: 0.484763760787723, gadget-2: 0.09996214648674212\n",
            "Average over 5 rewards realizations (same p, q)\n",
            "Gumbel-max: 1.1982375435761463, inverse-CDF: 4.53048527542513, gadget-1: 0.759666579397438, gadget-2: 0.15890002850746288\n"
          ]
        }
      ]
    },
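    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketchCounterfactualMd"
      },
      "source": [
        "The experiments above run with `counterfactual = True`, so next states under q are sampled conditioned on the observed next state `obs_next_state`. For the Gumbel-max mechanism alone, this conditioning can be illustrated with a simple rejection sampler over the shared Gumbel noise: keep the noise draws under which the factual mechanism reproduces the observed state, then reuse them under q. This is a hypothetical sketch for intuition only: `fm.gumbel_max_coupling` may use a different sampler (for example exact top-down truncated-Gumbel sampling), rejection becomes slow when the observed state is unlikely under p, and the function name and toy distributions are hypothetical. In the toy example, the observed state 0 has the largest likelihood ratio q/p (1.2), so every counterfactual sample stays at state 0, which illustrates the counterfactual stability of Gumbel-max."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sketchCounterfactualCode"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def counterfactual_gumbel_max_sketch(logits_p, logits_q, s_obs, num_samples, rng, max_rounds=1000):\n",
        "    # Rejection sampling from the posterior over the shared Gumbel noise:\n",
        "    # keep only draws under which the factual argmax(logits_p + g) equals s_obs\n",
        "    kept, num_kept = [], 0\n",
        "    for _ in range(max_rounds):\n",
        "        g = rng.gumbel(size=(num_samples, logits_p.shape[-1]))\n",
        "        accept = np.argmax(logits_p + g, axis=-1) == s_obs\n",
        "        kept.append(g[accept])\n",
        "        num_kept += int(accept.sum())\n",
        "        if num_kept >= num_samples:\n",
        "            break\n",
        "    if num_kept < num_samples:\n",
        "        raise RuntimeError('too few accepted draws; s_obs may be very unlikely under p')\n",
        "    g_post = np.concatenate(kept)[:num_samples]\n",
        "    # Reuse the same posterior noise under the counterfactual distribution q\n",
        "    return np.argmax(logits_q + g_post, axis=-1)\n",
        "\n",
        "# Toy example (hypothetical p, q and observation)\n",
        "rng_cf = np.random.default_rng(0)\n",
        "logits_p_cf = np.log(np.array([0.5, 0.3, 0.2]))\n",
        "logits_q_cf = np.log(np.array([0.6, 0.3, 0.1]))\n",
        "cf_samples = counterfactual_gumbel_max_sketch(logits_p_cf, logits_q_cf, s_obs=0, num_samples=1000, rng=rng_cf)\n",
        "print(np.bincount(cf_samples, minlength=3))  # all 1000 samples at state 0"
      ],
      "execution_count": null,
      "outputs": []
    },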
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EhS8-MK1vqjK"
      },
      "source": [
        "### Comparison to fixed causal mechanisms (inverse-CDF, Gumbel-max)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 345
        },
        "id": "CZ4OYN7VcM9m",
        "outputId": "b4a55519-4bb7-48b7-de95-c9033a71dc98"
      },
      "source": [
        "utils.plot_mdp_variances(vars_gm, vars_icdf, vars_gd1, vars_gd2, cf=counterfactual,\n",
        "                         generalized=noise_scale > 0)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFICAYAAAARY/SjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd7wcZb3H8c+XhAChhwQCSggIomBDooh4BQWRCyoWQCxgRMEGWNBYgYDt3lBEBUu4QAQLvYg0KQZEpQTBAkoRAhI5kIQaSM/v/vE8m0wmu+fsSfbMnrPn+3699nXOzjw789uZ3f3NPGVGEYGZmZlVZ7V2B2BmZjbYOPmamZlVzMnXzMysYk6+ZmZmFXPyNTMzq5iTr5mZWcWcfG3AkzRc0g8kPSJpsaTp7Y6plSTtJikkjW93LANV3n5TStOmS5ranoh6v35JG0t6RtKhfRjWSlHyZ0lntTuWgcLJ1/qMpPGSPlfBqr4MHAGcB4wH+mydkj7nJGht8i1gJtDvElykC0ZMBA6W9Jo2hzMgDG13ANbRxgNjgVP6eD1vA/4WEV/q4/VASuzTgSkVrMv61rbAgLjKkKQXA4cAR0XEonbHU09E/DrXOn0d2L/N4fR7PvO1AUnSWpJqB4+jgSfbGY9VQ9K6rVpWRMyPiAWtWl4f+wTpQOFX7Q6kBz8H9pU0ut2B9HdOvgOMpGGSJki6S9ILuQ1omqTDS+XGSjpH0uOS5kv6l6TvSBpeKjdFUt2j/3I7WV5mSJoo6R2Sbpc0T9Jjkk4oJEPyEfCuwBb5NbXHboUy2+QYH5O0ILeBnSBp7XoxShol6UxJjwPPAx/OsW8J7FpYx8T8uj0lnSfpQUlzJT0t6beSdm3wfreWdJakR3M8/5F0maQda9sD2KK0rpA0tt72Kix3fJ33vpmkk/J+fCpvx3skfVnSkHrxrSpJ75P0l7yuRyQdK2mPeu3JktaQ9DVJd+fyT0u6XNIOpXJL26MlfTSXny/pYUkTGsQxTtIlkmblsvdK+nrx85PLTc2fia0kXSjpSeDZPG+1/JqbJHXl/fWIpB9L2qjJ7bFcm2v+XEc3j7GFsutL+l9JD+T3MFPSryRtVWc9m0s6X+m7+mzeji9pJsaC/YFpEfFEadnF7X+EpPvy/rpP0hG9XMcKJG0o6fS8r57P+2TH2r6p85KrgNWBd6/qujudq50HEEnDgGuA3YDfko4y5wGvBN4LnJrLbQHcBqwP/Ai4P7/mq8AuknZfxaqrvYFPAz8BzgT2Bb4IPAV8J5f5HPBdYCTw+cJr/5Fj3BG4AXga+CkwA3g1cGSOcdeIWFha77VAF/BNYG3gHuAg4HvALODbudxf89/xwAjgbOBR4EXAx4HrJb0lIn5fW7CkccD1pB+OM4C/59fuCrwRuKPBuiC1w/XWq0j77BLgX3m9ewH/A2xFOtNpGUnvJ501/Qs4DlgEfAR4Z52yqwNXk973OaTP1frAocAfJL05IqaVXvZJYBPStnsa+DDwv5IejYhfFpa9D3Ax8ABwEqnGYmfgeOA1rFhduQ5wI/AHUnXmxnn6MOBLwEXAZaSDsdcBHwPeJGnHlTirrcVVtGaOcyjwXH4P6wN/BMaQPv93A5uSvhO3ShoXEQ/nshsANwGbk74v95A+U78D1momKEmbkKrIf9BNsSNINUA/zXF+APiBpBERcVwz66mz3tVJvzevI30ObiHto+uA2Q1e9mdgPun35icrs95BIyL8GCAPYAKp6uk7deatVvj/F7nc3qUyJ+TpHytMm0LuL1FnmQFMKTwfm6c9D4wtTBcpWT1Wev1UYHqDZf8F+Cewbmn6e/I6xpdjBH7eYFnTgal1pq9dZ9ompOR5ZZ345wGv6mHb1l1Xve1VmD4+z9utMG0tQHXKngMsBjYtTNutvE16+bkZSjq4eRzYsDB9HeDBOtv783na20vLWQ94pPj+C7
H9B1i/MH046aDkT4Vpa5IOnm4ChpaWXVtncRtNzdO+Vec9CVirzvSP5dcc0NO+6W5fFtZxLrAEeE9h+veBucCrS+W3IJ2ZF78z38nr/mip7Cl5esP1F8q+JZc9ss682vZ/DnhxYfow0gH4wuL0Xn5uDsvLPq40/XN5eqPv9gOkPhi9XudgerjaeWD5EOns8vjyjIhYAqk6DngXcGdEXFkq9l3yD8kqxnFpREwvrDtIR/KjJa3T04slvZJ05vdLYA1JI2sP4GZSct+zzktP7E2QEfF8YZ3r5OrIxcCtwE6Foq8BtgfOioi/UlLbtq0UEXPzdqs1JYzI7/8aUnPQuBaubkdgM1JSeKoQwxzqn518mHRgdEdp3wwj1T68SVL5rO2siHimsOwXSGdK2xTKvI108HMWsEFp2bXPalP7PZK5AJKGSNogL+eGXGSn8mtWwjeB9wNfiYhL8rpE+h7eBMwovYfnSe+5+B7eTTroObu07P/tRRyj8t/u+jX8IiIerT2JdNb/PdKB1wq1G016N+n7clJp+o/J1f8NzGZZDYU14GrngWUb4K6ImNdNmVGkM5q7yzMi4klJj5GqNVfFg3Wm1aqhNgLm9PD6l+e/x+VHPZvUmXZfz6Etk9vVvg28HdigNLvYzl1LEHf2ZvmrIrdvfgU4GNiadJZVtGELV7dl/ntvnXn1pr2cdGbeXXX6SODfheeNPhPF9tfafj+zm+WW9/vMiHi6XkFJBwBHATuQqu2LVmn7SfoIqZr7jIiYVJg1ivSe9qTx9ikerG0F3B4Ri4sFIuIxSXXfVx21z2r5M1L0jzrT7inEsDK2ItVmLZdoI2K+pAdpvI3FAOlF3k5Ovtaos1V3n43F3czr7geiXOYkUttiPU+VJ+SzqabkM/CbSG3DpwB/I1XNLSG1fb+12WW1QL1teTLLxiZ/G3iCVEX4WtJZUTtrpUTaXl/opkw58XT3mSguF1Jb7V0Nyvyn9LzuPpf0XtK2uw34LOlAYB4whPSZWuntp9Qx7nTSWfSnyrPz3+vo3dnrqqht6xEVrW9VjWDl+kEMKk6+A8t9wMskrRER8xuUmUlKMtuXZ0jakNQxpPjD92SeNyIiitVaq3p2DI2Pfu/PfxdHxHUtWE89u5OqWg+JiOUuSiDpW6WytTPqZi4O0N0R/ZPU/4Gsty0PAm6KiANLsW3dRAy9NT3/3bbOvHrT7ied4d3Q4ir32n5/vgX7/SBSsn1L8aBM0stWZaGStiV1vHoQ2C9W7PQ3k9ShbL0m38ODwDaShhTPfiVtyoq1MY3UarG26abMy+tM264Qw8p4ENhT0nrFs19Ja5A+0yscIOd5m5O2oXXDbb4Dyy9IVT3fKM/IbVG19snLgR0k7VUq9hXSPr+kMK2WePYolT2qBfHOATasxVZwJ6mD0ycbDM0YKmlVj/JrP3TLrVvSnqzYHvgX0g/cIZLqHbQUlzGHxmcg9wE7qzCcKx/wfLRBfOXY1mb5nuGtMg14DBif46mtbx1SL+Wys0k9Z+ue+ebetyvjGtIZ/lfq7V+lsdvNjuNdTDoQWvoblvfTCt+NZuU+AVeQakf2KbaP1+Tv1y+A10var8Fyiu2dl5Gq0g8uFftys3FFxEzS5/MN3RT7kNKFOGoxDCN9lhYDv2l2XSWXkWoSyr8FnyJ1vqtnB1LfgBtXcp2Dhs98B5bvkzpPfEPS60jDjeaRznK3ZVkC/Rqpc8ulkn5E6n34ZlLnkZuAnxWW+StSj8zJ+azhSdKQl5EtiPcW4B3AqZL+SPohuCEinpB0EKla76+SasM1hpPaP99LqhqesgrrvpnUs/YkpfGZj5LObA8iVam+slYwIkLSR0lDjW6TVBtqtAFpWMjVwA8L7+ljkr5JamdbAlyeO3edShr+dYOkc/LrDwUeJiWzoguBT0g6j1SFuQnpCkaNhnCsQGmM6q7AlsUOcGURsUjSF0lJo/b+FpF6Yc8mtQkXz+i/T/r8nCDpraT99CxpaM3u5D
POZuMsxPG8pIOBS4F7835/gLSdXkba7+8h9XLuyYXA+0jb+myWjS0d3u2ruvcj4CWkTmg7S9q5NP+SvJ+/DuwCnC/pfNJnYgGpt/PepGFp4/NrJgEfBE7Pw+vuJvVQ3pnU675ZFwBHS9o0Ih6rM/8+0jCnn5Bqvj5IGiL0zYhY2jafq9R/B/wsIsbXWU7RWaQez8dI2hL4Eym57k8aslYvf+xNaj65tPm3Nki1u7u1H717kIZrfJ30JZ5HqgK7Hfh0qdyWpGErT5B+GB4kJdnhdZa5E2kc5TzSD8Jk0g9io6FGE+ssY2KeN7YwbThp3OfjLDtT2a0wfwvSD930HONs0g/Xd4HNC+Wm0GA4VJ4/nfpDjV5FSpxPkX6QpgL/1Wh5pAOYn5OS9gJS++OlwGsLZTYmjS19kpR4y+/5S6RkO5+UnA+h/lCj4aShXw/n7X4/qWZid1Yc+rNbeVqefgeph+0GTX529ieNgZ5PGjJ0LMuGdpWH5gwljbm+Pa/j+RzjL4A9e4qtu/0GvCJv5xl5Oz9OGjd7NDCiUG4qDYaz5PmHkjoVzSOd2U8m1UrUG1bU41Ajlg1tavQof7aPJh3Izc2fr3+Q2op3Kq1nDOlg4dn8uJyU5Jdbfw/7bjNSUjuqNH3p9s/76/68f+8HPltnOe/M5b/d5HpHkL7Ds/NnYCqpJ37dfUP6nbmgmWUP9ofyBjOzASRXH88k/YgeuwrLOYo0lGfniLilVfFZ6+Wz2j2BbSO3RRfOZD8aEVOaWMbJpES9dSzfx6O3sUwlHYyMLUzbl9TWu2NENOpQZ5nbfM0Gpj1IyXdSTwVh6VjiIaVp6wCfIZ3V/LnlEVqrHUMa5lSvD0Gz3k46YGvptdBze/tE4Gwn3ua4zddsAIqIC0jtgM3aCrhK0rnAQ6Re7x8hNU98KgbODQYGrUjXdV5/FZexQofCVohUhbpDjwVtKSdfs8FhJqlj0IdI7daLSO2VX4mI89sZmNlg5DZfMzOzirnN18zMrGIDqtp55MiRMXbs2HaHYWZm1qM77rhjVkSMqjdvQCXfsWPHMm1a+TaiZmZm/Y+khxvNc7WzmZlZxZx8zczMKubka2ZmVjEnXzMzs4o5+ZqZmVWs8uSb79X6FUn3S5ov6VFJ36s6DjMzs3Zpx1CjKcBbgeOAfwKbA9u1IQ4zM7O2qDT5StqLdEP3V0fEPVWu28zMrL+outr5EOAGJ14zMxvMqk6+OwH3STpV0rOSXpB0saTNKo7DzMysbapu8x0NjAf+AhwIrEu6Gfglkt4QdW6xJOkw4DCAMWPGVBepWcmECRPo6upi9OjRTJrU1D3szczqqjr5Kj/2jYjZAJIeA24kdcK6vvyCiJgMTAYYN26c739obdPV1cWMGTPaHYaZdYCqq52fAv5WS7zZzcAC3OPZzMwGiaqT7z9IZ75lApZUHIuZmVlbVJ18fwO8UtLIwrQ3A6uT2oHNzMw6XtXJdzIwG7hc0jslfRA4B7guIm6uOBYzM7O2qDT5RsSzpI5VTwHnAqeROlkdUGUcZmZm7VT55SUj4gFg76rXa2Zm1l/4rkZmZmYVc/I1MzOrmJOvmZlZxZx8zczMKubka2ZmVjEnXzMzs4o5+ZqZmVXMydfMzKxiTr5mZmYVc/I1MzOrmJOvmZlZxZx8zczMKubka2ZmVjEnXzMzs4o5+ZqZmVXMydfMzKxiTr5mZmYVc/I1MzOrmJOvmZlZxZx8zczMKubka2ZmVjEnXzMzs4o5+ZqZmVXMydfMzKxiTr5mZmYVc/I1MzOrmJOvmZlZxZx8zczMKubka2ZmVjEnXzMzs4pVnnwljZcUdR6frDoWMzOzdhjaxnW/FZhbeP5guwIxMzOrUjuT7+0RMaeN6zczM2sLt/mamZlVrJ3J91+SFkm6V9In2hiHmZlZpdpR7fwYcDRwGzAEOBD4iaThEfG9NsRjZmZWqcqTb0RcA1xTmHSVpDWBb0j6fkQsKZaXdBhwGMCYMWOqC9
TMzKyP9Jc23wuBEcDY8oyImBwR4yJi3KhRoyoPzMzMrNX6S/KN0l8zM7OO1V+S737ALODhdgdiZmbW1ypv85V0Eamz1V9JHa7enx9Hltt7zczMOlE7ejvfCxwCbA4IuAc4OCLOaUMsZmZmlWtHb+evAV+rer1mZmb9RX9p8zUzMxs0nHzNzMwq5uRrZmZWMSdfMzOzijn5mpmZVczJ18zMrGJOvmZmZhVz8jUzM6uYk6+ZmVnFnHzNzMwq5uRrZmZWMSdfMzOzijn5mpmZVawdtxQ0W84jx7+y3SE0ZdGTI4ChLHry4QER85hj/tbuEMysAZ/5mpmZVczJ18zMrGJOvmZmZhVz8jUzM6uYk6+ZmVnFnHzNzMwq5uRrZmZWMSdfMzOzijn5mpmZVczJ18zMrGJOvmZmZhVz8jUzM6tYw+Qr6beSti08l6RjJI0ulXu1pPv6MkgzM7NO0t2Z7x7A+qWyxwKblcqtCbykxXGZmZl1rN5WO6tPojAzMxtE3OZrZmZWsbYmX0kvkjRHUkhap52xmJmZVWVoD/PfKGlk/n81IIBdSp2uXrYK6z8BmAOsvQrLMDMzG1B6Sr4n15n2/TrTorcrlvRmYC/gO6QkbGZmNih0l3y37KuVShoC/BA4Hni6r9ZjZmbWH3WXfHcFroiI2X2w3k8CawCnAR/qg+WbmZn1W911uDqLPhi/K2kj4JvAFyJiYauXb2Zm1t91l3z7akzvt4FbIuLKZgpLOkzSNEnTZs6c2UchmZmZVafSoUaStgcOAY6XtIGkDYDhefb6ktYqvyYiJkfEuIgYN2rUqCrDNTMz6xM99Xb+uKS9mlhORMQ3myi3DbA68Kc68x4FzgA+3sRyzMzMBqyeku/+wKImlhOkdtye3Ay8pTRtL+DLwN7Ag00sw8zMbEDrKfm+PSJua9XKImIWMLU4TdLY/O/vI2JOq9ZlZmbWX/nazmZmZhVrSfKVtPrKvjYipkSEfNZrZmaDRXfJ92FgfqOZSnaX9H/A4y2PzMzMrEM1bPONiLqXl5T0BuADpM5YmwBPAr/qk+jMzMw6UE8drgCQ9EpSwj0Q2AJYAAwDvgCcFhHN9Ig2MzMzuql2lrSVpK9L+jtwF3AUcDdwMGm8roA7nXjNzMx6p7sz3wdI43dvBT4BXBQRTwFIWr+C2MzMzDpSTx2uBLwC2A14o6SmqqnNzMyssYbJN3e4eiMwBdgduBx4XNLp+XlUEaCZmVmn6Xacb0TcEhFHAi8C9gQuBd4HXJiLHCppXN+GaGZm1lmaushGRCyJiOsi4mOk4UXvAc7Pf2+V9I8+jNHMzKyj9PoKVxGxMCIui4gPABsDBwH3tzwyMzOzDrVKl5eMiBci4pcR8a5WBWRmZtbpfGMFMzOzijn5mpmZVczJ18zMrGJOvmZmZhXr1RWrJG0H7AhsDpwZEV2StgYej4jn+iJAMzOzTtPsXY3WAc4kXWBjUX7d1UAX8B3gEeCLfRSjmZlZR2m22vlk0qUm9wDWJV3zueZKYK8Wx2VmZtaxmq12fi/w2Yj4naQhpXkPk+7xa2ZmZk1o9sx3LWB2g3nrAotbE46ZmVnnazb53g4c3GDefsAfWxOOmZlZ52u22vlo4FpJ1wEXkG4nuLekz5OS75v7KD4zM7OO0+xdjX5PuofvGsCppA5XxwFbAXtExO19FqGZmVmHaXqcb0T8AfgvSWsBGwJPR8QLfRaZmZlZh2p2nO+6wDoR8VhEzAXmFuZtCjwXEXP6KEYzM7OO0uyZ7xnAM8ChdeZNBNYHDmxRTGZmZh2t2d7ObwauaDDvStzhyszMrGnNJt/1gUbtu/NIbcBmZmbWhGaT7/3APg3m7Q38qzXhmJmZdb5m23x/CPxE0gJgCvAYsCnwEeAzwKf6JDozM7MO1FTyjYjTJW0CfBX4QmHWPOAbEXF6M8uRtF9+/bbA2qTrQp8DTIqIBb0J3MzMbKDqzTjfb0n6IbAzsBHpWs9/iohnerG+jYAbgB
OAp4HXk3pLjwYO78VyzMzMBqymky9ATrRXr+zKIuKnpUm/k7Qe8BlJR0RErOyyzfrayDWXAIvyXzOzldd08pW0JmlI0YuBNUuzIyJ+vJIxzAaGreRrzSrzxVc93e4QzKxDNHuFqzcBFwGjGhQJoOnkm+8JvAbwWuBI4Mc+6zUzs8Gi2aFGPwAeBHYA1oiI1UqPIb1c7/P58XvgRuBLvXy9mZnZgNVs8t0WmBgRf4mIhS1Y7xuB/wKOAvYl3SmpLkmHSZomadrMmTNbsGozM7P2arbN96+kHsktERF/zv/eLGkW8DNJJ0XEChfriIjJwGSAcePGuWrazMwGvGbPfD8FfF7Srn0QQy0Rb9kHyzYzM+t3mj3zvRYYDtyQr3L1XLlARGy8kjHskv8+tJKvNzMzG1CaTb6nkXo0rxJJVwPXAXcDi0mJ9yjgvHpVzmZmZp2o2ctLTmzR+m4HxgNjgUWkHtRfBX7SouWbmZn1e726wtWqioijgaOrXKeZmVl/05srXO0MfAx4KSte4YqIeH0L4zIzM+tYTfV2lvQ24CbSpSXfBMwE5gCvJt0s4e99FaCZmVmnaXao0fHA94F98vOjI+KtpLPghcDU1odmZmbWmZpNvtsBVwFLSL2e1waIiIdJtwT8el8EZ2Zm1omaTb7zgNXyzQ8eA15SmPcsqTrazMzMmtBsh6u/kK7vfC1wPfBVSTOABaQq6b/1TXhmZmadp9kz31NYdpGNr5HuSHQN8DtgY+AzrQ/NzMysMzV7kY0rC//PkLQjsDWwFvDPiFjQR/GZmZl1nJW6yEZu+72/xbGYmZkNCg2Tr6RPAxdExMz8f3ciIn7c2tDMzMw6U3dnvqcC00gX1Gh4s/ssACdfMzOzJjRMvhGxWr3/zczMbNX0mFQlrSnpdElvqCIgMzOzTtdj8o2IecCB1LmZgpmZmfVes9XJNwBv6ctAzMzMBotmhxqdBvyfpLWBK4HHWXbRDQAi4p4Wx2ZmZtaRmk2+V+e/X8iPYuJVfj6khXGZmZl1rGaTr6uczczMWqTZy0ve2NeBmJmZDRa9vrykpNWo0/M5Il5oSURmZmYdrqnezkq+LOkBYCHwXJ2HmZmZNaHZoUZHAl8BziB1sPo26T6+9wHTgcP6IjgzM7NO1GzyPRQ4FpiUn18aEccB2wP/BLbpg9jMzMw6UrPJd0vgrohYTKp23gAgIpYAPwI+0jfhmZmZdZ5mk+9sYJ38/yPADoV5GwJrtTIoMzOzTtZsb+c/AK8jXd3ql8BESSOABcBngOv7JjwzM7PO0zD5Slo9IhbmpxOBF+X/v0Oqdh5POuO9Fjii70I0MzPrLN2d+T4u6SLgV8DvIuJegIiYD3w2P8zMzKyXumvz/SXwDtKZ7X8knSJpp2rCMjMz61wNk29EHE6qan47cAVwEPBHSQ9K+rakV1YUo5mZWUfptrdzRCyJiOsi4uPAJsC+pM5XhwN3Sfq7pK9J2qqCWM3MzDpCs0ONiIhFEfGbiDgI2BjYn3SBjdqVrnokaX9Jv5Y0Q9IcSXdI+sBKRW5mZjZA9frGCtkOwJuBN5IS+PQmX/cF4CHg88AsYG/gl5JGRsQPVzIWMzOzAaXp5CtpB+BA4ABgDPAEcD7wq4i4pcnFvDMiZhWe3yBpM1JSdvI1M7NBodvkK+llwAeA95Ou3/wMcDHLhh8t6c3KSom35k7gfb1ZjpmZ2UDW3UU2/kq6ccJc4HJgAnBV4cIbrbIzTbYZm5mZdYLuznynA98FLouIF/pi5ZJ2B94NHNJNmcPItywcM2ZMX4RhZmZWqYbJNyLe1ZcrljSWdCGPyyJiSjdxTAYmA4wbNy76MiYzM7MqND3UqJXyTRmuAh4GPtSOGMzMzNql8uQraTjwG2AY8I6+qtI2MzPrr1Z2nO9KkTQUuIDUc/qNEfFEles3MzPrDypNvsCPSBfW+CywkaSNCvPuzHdMMjMz62hVJ98989/v15m3Jc1fKcvMzGzAqjT5RsTYKt
dnZmbWH7Wlt7OZmdlg5uRrZmZWMSdfMzOzijn5mpmZVczJ18zMrGJOvmZmZhVz8jUzM6uYk6+ZmVnFnHzNzMwq5uRrZmZWsaqv7WwlEyZMoKuri9GjRzNp0qR2h2NmZhVw8m2zrq4uZsyY0e4wzMysQq52NjMzq5iTr5mZWcWcfM3MzCrm5GtmZlYxJ18zM7OKOfmamZlVzMnXzMysYk6+ZmZmFXPyNTMzq5iTr5mZWcWcfM3MzCrm5GtmZlYxJ18zM7OKOfmamZlVzMnXzMysYh17P98dv3R2u0NoyrqznmMI8Mis5/p9zHeccHC7QzAz6wg+8zUzM6uYk6+ZmVnFKk++kraW9FNJf5W0WNLUqmMwMzNrp3a0+W4P7A3cAqzehvWbmZm1VTuqnS+PiM0jYn/g7jas38zMrK0qT74RsaTqdZqZmfUnHTvUyMwGrwkTJtDV1cXo0aOZNGlSu8MxW4GTr5l1nK6uLmbMmNHuMMwa6vdDjSQdJmmapGkzZ85sdzhmZmarrN8n34iYHBHjImLcqFGj2h2OmZnZKuv3ydfMzKzTOPmamZlVrPIOV5KGky6yAfAiYD1J++XnV0bEC1XHZGZmVqV29HbeGLigNK32fEtgeqXRmJmZVazy5BsR0wFVvV4zM7P+wm2+ZmZmFXPyNTMzq5ivcNVmS4atvdxfMzPrfE6+bfb8Nnu2OwSzpu3yw13aHUJThj09jNVYjX8//e8BEfMfjvhDu0Owirna2czMrGJOvmZmZhVz8jUzM6uYk6+ZmVnFnHzNzMwq5uRrZmZWMSdfMzOzijn5mpmZVczJ18zMrGJOvmZmZhVz8jUzM6uYr+1sZh0nhgdLWEIMj3aHYlaXk6+ZdZyFuyxsdwhm3XK1s5mZWcWcfM3MzCrm5GtmZlYxJ18zM7OKOfmamZlVzMnXzMysYk6+ZmZmFXPyNTMzq5gvsmFmZm01YcIEurq6GD16NJMmTWp3OJVw8jUzs7bq6upixowZ7Q6jUq52NjMzq5iTr5mZWcVc7Wxm1qFufPOu7Q6hKXOHDgGJuY8+OiBi3vWmG1d5GZWf+UraTtL1kl6Q9B9Jx0saUnUcZmZm7VLpma+kDYHrgHuAfYGXACeRDgK+UWUsZmZm7VJ1tfMngWm64RgAAA3dSURBVLWA90bEs8C1ktYDJkqalKeZmZl1tKqrnf8buKaUZM8lJeT+X9FvZmYtt0EEIyLYIKLdoVSm6jPflwE3FCdExCOSXsjzLq84HjMza7MPL17S7hAqV/WZ74bA03WmP5XnmZmZdbx+P9RI0mHAYfnpHEn3tjOePjISmNXuIHqiEz/S7hD6gwGxrwA4Vu2OoN0GzL7SkYN+X8EA2l+o6f21RaMZVSffp4D160zfMM9bQURMBib3ZVDtJmlaRIxrdxzWM++rgcP7amAZbPur6mrnf5LadpeStDkwPM8zMzPreFUn36uAt0tatzDt/cBcYNUvGWJmZjYAVJ18fwLMBy6WtEduz50InDzIx/h2dLV6h/G+Gji8rwaWQbW/FBWPq5K0HXAqsDOp5/P/ARMjYnGlgZiZmbVJ5cnXzMxssOvIWwpKerek30qaLWmBpBmSLpS0V4UxTJd0YouWFZIO76HM2FwuJL2pzvxv5HnTWxFTq0iaKGlgDC9oIUnvk3SDpKclzZd0n6STJW1WKBOFx1xJj0i6WNI76yxvYql87XFdte9shbj67Lso6R35PY5d9UhXWPZL8zbdoMnyb5P0q/y9D0kTWx1TXxsM+0rSEElflvT7/D5n5/f8ulbH1ZOOS76SvgdcBMwAPg7sAXyFdAnLqyS9pI3hVWEOcGCd6QfmedZmkk4CzgceBA4C9gS+B+wOnFYqfhKpiWZP0ud4AXCZpDPrLPqZXLb4OKIP3kJTBvh38aXAsUBTyRfYC3gVcD3wQl8F1VcG0b5ai/S+bid99z4MLARulrRjn0VYR7+/yEZvSNoX+Bzw0YiYUpp9Tj5jmFt5YNW6HN
hP0mdr7eiSXgm8nPSDv3M7g+vPJAlYIyLm9eE63gl8AfhYRBQT6I2SJpOSbNH0iLil8PyXkn4LnCHpxoj4WWHeolLZthmE38UvRcRRsPS9DxiDbF/NBbaKiKXXlZB0PXAfcDjw0aoC6bQz388Bt9f5AAEQEZdHxH8KVbTvKM6XNEXStMLziZJmSdpJ0rRc9XezpC0lbSzpUklzJP1D0lvrrVPS0ZK6crlfSFq/NH+EpMmSHpc0T9IfJe20Ctvg18C6wFsK0w4EbiYd1RbXvbakUyXdq3R/5YcknaZ0p6lamf0lLZG0e2HaWEnPSvr2KsS5Akm75f2ym6QL8jZ7UNKnC2XG5yqxDUqv3T6/do/CtH3zfpuX98EkSasX5tf275sk3Q7MA/aXtLqkE5Wqeecr3Xf6EknDCq8dI+lcSU/mbXeNpG2beJufB/5cSrwARMTiiLiqpwXk194KfKqJ9bVLU99FAElHSbpd0jP5e3C5pK2L5ZVMlPSEpOcknQ2sV15u3i9X5e/qQ/nzcqGkqaVyr5B0RV7Wc/nzNjrP241l15l/SE0010TEQL448aDZV/k79lRp2gLgbmCz+q/qGx2TfCUNJZ3V/bbFix5O6gL/PeADwBjgHOBXpIT2XlJSu0DS8NJrP0CqvjmUdLazD6l3dy3mNUj3N94D+BLwbmAmcF3tw7US5gC/yeuuOTDHW++9DQG+Trrj1NHAW4ELagUi4gLgPOBMSetJEnAW8BBw3ErG2JPTgb8A7wGmAqdJen2edykQeV7R+4HHgd8BSDoAuBi4DXhXjvUw4Lul1w0HfkbaL3vl8l8FPkTaHm8j/Tg9Q9pWSBpB2vfbkm6TeQCwNmm/rdXoTeXE/0bg6ia3Q3euBXYsHkzkdQwtPSq/buFKfBdfTBoBsS/puzIE+KOWP1A9EjiG9F3cj3QGM6m0XpEOPl8OHEL6zh0J7FQqtzXwB2BNUrXjeGB74PK8jD8DX8zF35vfS/nz1hG8r5b+Dr+WdPZbnYjoiAewCelH+ROl6SJVr9ceAsbmsu8olZ0CTCs8n5jL7VqY9uk87ZjCtO3ytP8uTJsOPAmsU5j2IWAJ8PL8/GOkNrxtCmWGAv8CTihMC+DwHt7/0vdE+vA9CQwDXk9q0xgJnEiqxmy0jKHALnk5YwrTRwD/Ac4gfUHmA69u0X6bCMzK/++W1318Yf7qpAOS/ylMuwy4urSce4FTC/v8YeCsUplDSD8EG5X2776lcr8BTuom5m8Cs4ERhWkbkhL0Z7p53eh6n9Fuyjfc78An8vxNSu+l/NijP38X67x2CKld7jng4MK0/wA/LpW9Nq9nbH6+T37+ukKZF+XP/9TCtHPy52VYYdo2wGJgn/z8HcVl9/L9zyINn6x0u3tf9X5f5dcfT/pN27bKbd8xZ74F5bFTR5F2aO3xmV4ubwHw+8LzB/LfG+pMe1HptddGRLGT0yWkD3WtZ90ewB2k6pKh+SgU0tW+6l7jNFfpFM9s6u3DK0lfgreTznqvj4i6PYolHSTpTklzyB0P8qyX1spExJOko9xDgBNIyfEv9ZbXIkuPwiNiIXA/6Yi75jxgd0kb5ffwmhzveYXYxwDnF7cVaZ+tCbyisKwgXXmt6C5gvKQJkl5V5+xxD9KPybOFZT9H2pfjckxDujkDbcX4vnpntM+QPlvFx60tWNfKauq7KOkNkq6VNBtYROqwtA7LPoObA5uSDrqKLi49fx3QFRG3Lw0gYgZpvxTtQfouLinsv4dIB8zdXlu4XLPQXdkBZlDuK0n7kGr+vhwRld60p5OS72zS0cuLS9PPYdkP0cp4LpZvz1mQ/y69NWKkNgNIP+xFTxSfRMQLpGrhTfOkkcAbWP5DvpDU6L95g3g+Uipbr+1wPql69oOkKtFz6y1I0nuAs4E/AfvnWGpVNuX3cgOpWnc1UrVwXyrfdnJBKZ5fk977+/Lz9wOPsuzAYWT+eyXLb6uH8vTitn2qsP9qvkXqdfxpUvX3vy
V9tjB/ZF5neb+9pbDs60vzdmXZZ3RMw3fevNpZwpOFaYsiYlrp8VwL1tVbTX8XJY0hHWyJdDa/S57/BMv2ea0JZrnvU53no0m1JGXlaSOBL7Pi/tuKxt+7mvJrBrpBu6+UhhedB/wkIk7pYVkt1zFHbhGxSNKfSL1FjylMf5yUNCicfNR6sw5jea2+p/DGxSe5TXgd4LE86UlgGvU7zsxvsMzLWf5AotEY2XNJ1acLSUeO9ewP3BoRxQ5NuzYo+z+ks+ku4BRSYm+LiJgj6QpSApxMOsC4IHIdEssS0mHAnXUW8VDh/xXOQiP1dj4GOEbSNqR23VMk3RsRV+fl/5pU/VxWS3afIHV8q7k3IhZK+gOpRuIbPb/Tbu0J3JFrBvqVXn4X9yK1u+8bEc/neUNJTR01Xfnvct+nOs+7gFF1QhrFsu88pP13CYX+FwU9jTmvfDxoXxqs+0rSS4ErSAfJR/awnD7RMck3OwW4VNJBEXFON+WeICWll9cmSFqH1Bnm4RbG8zZJ6xSqnt9D+rGv9ai+nvShfyQiykeGdUXEbNLRak+uJY3b+2dEPNOgzFqsmOQ/VC6UexQeQUpyzwLXSLooIi5qJuY+ci5wntIwiK1Y/uz+XlInuLERsUpn6RFxv6QvkqrdtiN1lrqetC3ujoi6QzC6qcI6Bfi1pI/E8sOEyE0Ie+YE35CkQ0ht+f35BsvNfhfXIvWDWFSYdgDL/zb9m/RjvS/Ld1Z7b2lZtwPHSnp9RNwGIOlFwI6kTjs115M67dxROGArq1ubFRHT6pQd6AbVvpK0KXANqW/NB6JNlzbuqOQbEZdJOgWYIuktpLPEWcBGLBs/OScilki6DPi8pIdJ1ZxH0fqxbHOBKySdQKpqPgG4JCLuyfPPJp1VTVW6GtaDOdbXk9pDvreyK46IRaQvRneuJfUk/jqpbXBv0oUelsoHJWcC50XEhXnaT4EfS7opIupVHVXhSlJ700+Bh2pfYEjDPiQdRRqjuB6pTXcBKUm/G9gvNwHUJekSUtvTnaR9uB/pu3JTLnIyqeflDZJ+SEr0m5Cqlm+OiHo9y2uxXS7pZNI43V1IbWNzSLfa/CSpLav4ozVW0htIHc9eTPpROwA4MyLO7mkjtUuz30XSdh4CnCXpDNIP7RdZvllnsaRJwIlKV0P7PanJYenBc3YlqZngfElfJe27Y0lncMWmo4mkXu1XKF2sZBapGv9twJSImEo6gAP4hKRzgRci4m+N3q+kLVh2pjUM2E7SfsDz0cTwsXYaTPtKaTTCVaRazsOBVxXO7OdHRL2asr5RZe+uqh6kM8xrSVUWC0m97y5i+d7Im5B++J4lne0eRv3ezrNKy96NdPb6itL05Xqmkn5ET8rLeBx4njTcZ4PS69YHvk86YlxAaru8GNil0bIbvOex1OnBXSqzXG9n0hfpRFJNwLN5G+1UXA4puT3G8j171yEdKFzUgn21dBt3s22nAhfWee3Pc/nvNlj2f5O+/M/n93cXqT13aKP9m6d/iVQ78QypGvlWVuwRvRlpyNXjpNqD6Tme7Zt83+8jDYt6Ju/3+/K+GF3a77XHvPwZuRh4Z3fbsT89aO67eBDpLGQucEv+DE4HTiyUEamaf2beJ78gNX0s18sV2IJ08DKPZd/r3wKXluJ6GXBhjmsuqdPkT4EXF8oclZexiG5GCeSy40v7q/bo9nX96TEY9hXLfifbvq98YwUz61hK408fJA1DO7bd8Vhjg21fdVS1s5kNbpI+Saq2vJ/UeecLwBrUGRVg7TXY95WTr5l1knmkoSlbkKoSbyNdaKSVHSmtNQb1vnK1s5mZWcU66SIbZmZmA4KTr5mZWcWcfM3MzCrm5GtmZlYxJ18zM7OKOfmamZlV7P8B71rCTqlVXboAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 554.4x360 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    }
  ]
}